"vscode:/vscode.git/clone" did not exist on "e7beba17b7025469b6c06352b38ceda509644bc9"
common.h 6.87 KB
Newer Older
1
2
/**
 *  Copyright (c) 2017-2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/common.h
 * @brief Common utilities for CUDA
 */
#ifndef GRAPHBOLT_CUDA_COMMON_H_
#define GRAPHBOLT_CUDA_COMMON_H_

#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <torch/script.h>

#include <memory>
#include <unordered_map>

namespace graphbolt {
namespace cuda {

/**
 * @brief This class is designed to allocate workspace storage
 * and to get a nonblocking thrust execution policy
 * that uses torch's CUDA memory pool and the current cuda stream:
 *
 * cuda::CUDAWorkspaceAllocator allocator;
 * const auto stream = torch::cuda::getDefaultCUDAStream();
 * const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
 *
 * Now, one can pass exec_policy to thrust functions.
 *
 * To get an integer array of size 1000 whose lifetime is managed by unique_ptr,
 * use:
 *
 * auto int_array = allocator.AllocateStorage<int>(1000);
 *
 * int_array.get() gives the raw pointer.
 */
struct CUDAWorkspaceAllocator {
  // Required by thrust to satisfy allocator requirements.
  using value_type = char;

  explicit CUDAWorkspaceAllocator() { at::globalContext().lazyInitCUDA(); }

  CUDAWorkspaceAllocator& operator=(const CUDAWorkspaceAllocator&) = default;

  void operator()(void* ptr) const {
    c10::cuda::CUDACachingAllocator::raw_delete(ptr);
  }

  // Required by thrust to satisfy allocator requirements.
  value_type* allocate(std::ptrdiff_t size) const {
    return reinterpret_cast<value_type*>(
        c10::cuda::CUDACachingAllocator::raw_alloc(size));
  }

  // Required by thrust to satisfy allocator requirements.
  void deallocate(value_type* ptr, std::size_t) const { operator()(ptr); }

  template <typename T>
  std::unique_ptr<T, CUDAWorkspaceAllocator> AllocateStorage(
      std::size_t size) const {
    return std::unique_ptr<T, CUDAWorkspaceAllocator>(
        reinterpret_cast<T*>(allocate(sizeof(T) * size)), *this);
  }
};

inline auto GetAllocator() { return CUDAWorkspaceAllocator{}; }

inline auto GetCurrentStream() { return c10::cuda::getCurrentCUDAStream(); }
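
// Example (illustrative sketch mirroring the CUDAWorkspaceAllocator comment
// above; assumes the relevant thrust execution policy headers are included by
// the caller):
//
//   auto allocator = cuda::GetAllocator();
//   auto stream = cuda::GetCurrentStream();
//   const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);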

template <typename T>
inline bool is_zero(T size) {
  return size == 0;
}

template <>
inline bool is_zero<dim3>(dim3 size) {
  return size.x == 0 || size.y == 0 || size.z == 0;
}

#define CUDA_CALL(func) C10_CUDA_CHECK((func))

#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...)    \
  {                                                                   \
    if (!graphbolt::cuda::is_zero((nblks)) &&                         \
        !graphbolt::cuda::is_zero((nthrs))) {                         \
      (kernel)<<<(nblks), (nthrs), (shmem), (stream)>>>(__VA_ARGS__); \
      C10_CUDA_KERNEL_LAUNCH_CHECK();                                 \
    }                                                                 \
  }
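
// Example (illustrative sketch; `MyKernel`, `grid`, `block` and the kernel
// arguments are hypothetical, not part of this header):
//
//   const dim3 grid(128), block(256);
//   CUDA_KERNEL_CALL(
//       MyKernel, grid, block, 0, cuda::GetCurrentStream(), arg0, arg1);
//
// The launch is skipped when any grid or block dimension is zero, and the
// launch status is verified with C10_CUDA_KERNEL_LAUNCH_CHECK.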

/**
 * @brief This class is designed to handle the copy operation of a single
 * scalar_t item from a given CUDA device pointer. The copy is issued
 * asynchronously; when the object is later cast to scalar_t, it synchronizes
 * and the value can be read.
 *
 * auto num_edges = cuda::CopyScalar(indptr.data_ptr<scalar_t>() +
 *     indptr.size(0) - 1);
 * // Perform many operations here, they will run as normal.
 * // We finally need to read num_edges.
 * auto indices = torch::empty(static_cast<scalar_t>(num_edges));
 */
template <typename scalar_t>
struct CopyScalar {
  CopyScalar(const scalar_t* device_ptr) : is_ready_(false) {
    pinned_scalar_ = torch::empty(
        sizeof(scalar_t),
        c10::TensorOptions().dtype(torch::kBool).pinned_memory(true));
    auto stream = GetCurrentStream();
    CUDA_CALL(cudaMemcpyAsync(
        reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr()), device_ptr,
        sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
    copy_event_.record(stream);
  }

  operator scalar_t() {
    if (!is_ready_) {
      copy_event_.synchronize();
      is_ready_ = true;
    }
    return reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr())[0];
  }

 private:
  torch::Tensor pinned_scalar_;
  at::cuda::CUDAEvent copy_event_;
  bool is_ready_;
};

// This includes all integer, floating point (including Half and BFloat16) and
// boolean types.
#define GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(...)            \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)                 \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Bool, __VA_ARGS__)

#define GRAPHBOLT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
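
// Example (illustrative sketch; `tensor` is a hypothetical torch::Tensor):
//
//   GRAPHBOLT_DISPATCH_ALL_TYPES(
//       tensor.scalar_type(), "ExampleOp", ([&] {
//         // Inside the lambda, `scalar_t` is bound to the concrete element
//         // type, including Half, BFloat16 and Bool.
//         auto* data = tensor.data_ptr<scalar_t>();
//       }));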

#define GRAPHBOLT_DISPATCH_ELEMENT_SIZES(element_size, name, ...)             \
  [&] {                                                                       \
    switch (element_size) {                                                   \
      case 1: {                                                               \
        using element_size_t = uint8_t;                                       \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      case 2: {                                                               \
        using element_size_t = uint16_t;                                      \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      case 4: {                                                               \
        using element_size_t = uint32_t;                                      \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      case 8: {                                                               \
        using element_size_t = uint64_t;                                      \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      case 16: {                                                              \
        using element_size_t = float4;                                        \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      default:                                                                \
        TORCH_CHECK(false, name, " with the element_size is not supported!"); \
        using element_size_t = uint8_t;                                       \
        return __VA_ARGS__();                                                 \
    }                                                                         \
  }()
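
// Example (illustrative sketch; `feature` is a hypothetical torch::Tensor):
//
//   GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
//       feature.element_size(), "ExampleCopy", ([&] {
//         // `element_size_t` has the same byte width as the tensor's
//         // elements, so they can be moved without knowing the actual dtype.
//         auto* src = reinterpret_cast<element_size_t*>(feature.data_ptr());
//       }));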

}  // namespace cuda
}  // namespace graphbolt
#endif  // GRAPHBOLT_CUDA_COMMON_H_