/**
 *  Copyright (c) 2017-2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/common.h
 * @brief Common utilities for CUDA
 */
#ifndef GRAPHBOLT_CUDA_COMMON_H_
#define GRAPHBOLT_CUDA_COMMON_H_

#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <torch/script.h>

#include <memory>
#include <unordered_map>

namespace graphbolt {
namespace cuda {

/**
 * @brief This class is designed to allocate workspace storage
 * and to get a nonblocking thrust execution policy
 * that uses torch's CUDA memory pool and the current cuda stream:
 *
 * cuda::CUDAWorkspaceAllocator allocator;
 * const auto stream = torch::cuda::getDefaultCUDAStream();
 * const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
 *
 * Now, one can pass exec_policy to thrust functions.
 *
 * To get an integer array of size 1000 whose lifetime is managed by unique_ptr,
 * use:
 *
 * auto int_array = allocator.AllocateStorage<int>(1000);
 *
 * int_array.get() gives the raw pointer.
 */
struct CUDAWorkspaceAllocator {
  // Required by thrust to satisfy allocator requirements.
  using value_type = char;

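  // Ensures the CUDA context is initialized before any workspace is allocated.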
  explicit CUDAWorkspaceAllocator() { at::globalContext().lazyInitCUDA(); }

  CUDAWorkspaceAllocator& operator=(const CUDAWorkspaceAllocator&) = default;

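  // Frees a pointer obtained from allocate(); also serves as the deleter of
  // the unique_ptr returned by AllocateStorage().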
  void operator()(void* ptr) const {
    c10::cuda::CUDACachingAllocator::raw_delete(ptr);
  }

  // Required by thrust to satisfy allocator requirements.
  value_type* allocate(std::ptrdiff_t size) const {
    return reinterpret_cast<value_type*>(
        c10::cuda::CUDACachingAllocator::raw_alloc(size));
  }

  // Required by thrust to satisfy allocator requirements.
  void deallocate(value_type* ptr, std::size_t) const { operator()(ptr); }

  template <typename T>
  std::unique_ptr<T, CUDAWorkspaceAllocator> AllocateStorage(
      std::size_t size) const {
    return std::unique_ptr<T, CUDAWorkspaceAllocator>(
        reinterpret_cast<T*>(allocate(sizeof(T) * size)), *this);
  }
};

inline auto GetAllocator() { return CUDAWorkspaceAllocator{}; }

inline auto GetCurrentStream() { return c10::cuda::getCurrentCUDAStream(); }

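// Helpers used by CUDA_KERNEL_CALL below to skip kernel launches whose grid
// or block dimensions are zero.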
template <typename T>
inline bool is_zero(T size) {
  return size == 0;
}

template <>
inline bool is_zero<dim3>(dim3 size) {
  return size.x == 0 || size.y == 0 || size.z == 0;
}

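// Checks the result of a CUDA runtime call and raises a c10 error on failure.
// Illustrative use (the pointer, size and stream below are hypothetical):
//
//   CUDA_CALL(cudaMemsetAsync(ptr, 0, size, stream));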
#define CUDA_CALL(func) C10_CUDA_CHECK((func))

#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, ...)          \
  {                                                                 \
    if (!graphbolt::cuda::is_zero((nblks)) &&                       \
        !graphbolt::cuda::is_zero((nthrs))) {                       \
      auto stream = graphbolt::cuda::GetCurrentStream();            \
      (kernel)<<<(nblks), (nthrs), (shmem), stream>>>(__VA_ARGS__); \
      C10_CUDA_KERNEL_LAUNCH_CHECK();                               \
    }                                                               \
  }
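
// Illustrative usage of CUDA_KERNEL_CALL; the kernel name `MyKernel` and its
// arguments are hypothetical, not part of this header:
//
//   const dim3 grid(256), block(128);
//   CUDA_KERNEL_CALL(MyKernel, grid, block, 0, input_ptr, output_ptr, num_items);
//
// The launch and its error check are skipped entirely when any grid or block
// dimension is zero.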

#define CUB_CALL(fn, ...)                                                     \
  {                                                                           \
    auto allocator = graphbolt::cuda::GetAllocator();                         \
    auto stream = graphbolt::cuda::GetCurrentStream();                        \
    size_t workspace_size = 0;                                                \
    CUDA_CALL(cub::fn(nullptr, workspace_size, __VA_ARGS__, stream));         \
    auto workspace = allocator.AllocateStorage<char>(workspace_size);         \
    CUDA_CALL(cub::fn(workspace.get(), workspace_size, __VA_ARGS__, stream)); \
  }
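
// Illustrative usage of CUB_CALL, which runs the usual two-phase CUB pattern:
// a first call with a null workspace queries the required temporary storage
// size, and a second call runs the algorithm on the freshly allocated
// workspace. The device pointers and item count below are hypothetical:
//
//   CUB_CALL(DeviceReduce::Sum, d_in, d_out, num_items);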

#define THRUST_CALL(fn, ...)                                                 \
  [&] {                                                                      \
    auto allocator = graphbolt::cuda::GetAllocator();                        \
    auto stream = graphbolt::cuda::GetCurrentStream();                       \
    const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream); \
    return thrust::fn(exec_policy, __VA_ARGS__);                             \
  }()
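
// Illustrative usage of THRUST_CALL (the device pointer `d_in` and the size
// `num_items` are hypothetical); the wrapped call runs on the current stream
// with workspace memory drawn from torch's caching allocator:
//
//   THRUST_CALL(sequence, d_in, d_in + num_items);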

/**
 * @brief This class is designed to handle the copy operation of a single
 * scalar_t item from a given CUDA device pointer. Later, if the object is cast
 * into scalar_t, the value can be read.
 *
 * auto num_edges = cuda::CopyScalar(indptr.data_ptr<scalar_t>() +
 *     indptr.size(0) - 1);
 * // Perform many operations here, they will run as normal.
 * // We finally need to read num_edges.
 * auto indices = torch::empty(static_cast<scalar_t>(num_edges));
 */
template <typename scalar_t>
struct CopyScalar {
  CopyScalar() : is_ready_(true) { init_pinned_storage(); }

  void record(at::cuda::CUDAStream stream = GetCurrentStream()) {
    copy_event_.record(stream);
    is_ready_ = false;
  }

  scalar_t* get() {
    return reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr());
  }

  CopyScalar(const scalar_t* device_ptr) {
    init_pinned_storage();
    auto stream = GetCurrentStream();
    CUDA_CALL(cudaMemcpyAsync(
        reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr()), device_ptr,
        sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
    record(stream);
  }

  operator scalar_t() {
    if (!is_ready_) {
      copy_event_.synchronize();
      is_ready_ = true;
    }
    return *get();
  }

 private:
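  // Allocates sizeof(scalar_t) bytes of pinned host memory; the kBool dtype
  // is used because each of its elements occupies a single byte.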
  void init_pinned_storage() {
    pinned_scalar_ = torch::empty(
        sizeof(scalar_t),
        c10::TensorOptions().dtype(torch::kBool).pinned_memory(true));
  }

  torch::Tensor pinned_scalar_;
  at::cuda::CUDAEvent copy_event_;
  bool is_ready_;
};

// This includes all integer, float and boolean types.
#define GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(...)            \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)                 \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Bool, __VA_ARGS__)

#define GRAPHBOLT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
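
// Illustrative usage of GRAPHBOLT_DISPATCH_ALL_TYPES; the tensor `input` and
// the op name are hypothetical. The dispatch macro defines `scalar_t` inside
// the lambda:
//
//   GRAPHBOLT_DISPATCH_ALL_TYPES(input.scalar_type(), "MyOp", ([&] {
//     auto ptr = input.data_ptr<scalar_t>();
//     // ... launch kernels templated on scalar_t ...
//   }));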

#define GRAPHBOLT_DISPATCH_ELEMENT_SIZES(element_size, name, ...)             \
  [&] {                                                                       \
    switch (element_size) {                                                   \
      case 1: {                                                               \
        using element_size_t = uint8_t;                                       \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      case 2: {                                                               \
        using element_size_t = uint16_t;                                      \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      case 4: {                                                               \
        using element_size_t = uint32_t;                                      \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      case 8: {                                                               \
        using element_size_t = uint64_t;                                      \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      case 16: {                                                              \
        using element_size_t = float4;                                        \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      default:                                                                \
        TORCH_CHECK(false, name, " is not supported for this element_size!"); \
        using element_size_t = uint8_t;                                       \
        return __VA_ARGS__();                                                 \
    }                                                                         \
  }()
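
// Illustrative usage of GRAPHBOLT_DISPATCH_ELEMENT_SIZES; the tensor `input`
// and the op name are hypothetical. Inside the lambda, `element_size_t` is a
// type whose size matches `element_size`, which is handy for dtype-agnostic
// copies:
//
//   GRAPHBOLT_DISPATCH_ELEMENT_SIZES(input.element_size(), "MyCopyOp", ([&] {
//     auto ptr = reinterpret_cast<element_size_t*>(input.data_ptr());
//     // ... move elements around as opaque element_size_t values ...
//   }));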

}  // namespace cuda
}  // namespace graphbolt
#endif  // GRAPHBOLT_CUDA_COMMON_H_