/*!
 *  Copyright (c) 2017-2022 by Contributors
 * \file cuda_device_api.cc
 * \brief GPU-specific API
 */
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <hip/hip_runtime.h>
#include "cuda_common.h"

namespace dgl {
namespace runtime {

class CUDADeviceAPI final : public DeviceAPI {
 public:
  CUDADeviceAPI() {
    int count;
    auto err = hipGetDeviceCount(&count);
    switch (err) {
      case hipSuccess:
        break;
      default:
        count = 0;
        hipGetLastError();
    }
    is_available_ = count > 0;
  }

  bool IsAvailable() final {
    return is_available_;
  }

  void SetDevice(DGLContext ctx) final {
    CUDA_CALL(hipSetDevice(ctx.device_id));
  }
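  // Query a device attribute. Numeric attributes are returned through *rv as
  // an int; string-valued attributes (compute version, device name, max
  // thread dimensions) are written to *rv directly and return early.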
  void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) final {
    int value = 0;
    switch (kind) {
      case kExist:
        value = (
            hipDeviceGetAttribute(
                &value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id)
            == hipSuccess);
        break;
      case kMaxThreadsPerBlock: {
        CUDA_CALL(hipDeviceGetAttribute(
            &value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id));
        break;
      }
      case kWarpSize: {
        CUDA_CALL(hipDeviceGetAttribute(
            &value, hipDeviceAttributeWarpSize, ctx.device_id));
        break;
      }
      case kMaxSharedMemoryPerBlock: {
        CUDA_CALL(hipDeviceGetAttribute(
            &value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id));
        break;
      }
      case kComputeVersion: {
        std::ostringstream os;
        CUDA_CALL(hipDeviceGetAttribute(
            &value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id));
        os << value << ".";
        CUDA_CALL(hipDeviceGetAttribute(
            &value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id));
        os << value;
        *rv = os.str();
        return;
      }
      case kDeviceName: {
        hipDeviceProp_t props;
        CUDA_CALL(hipGetDeviceProperties(&props, ctx.device_id));
        *rv = std::string(props.name);
        return;
      }
      case kMaxClockRate: {
        CUDA_CALL(hipDeviceGetAttribute(
            &value, hipDeviceAttributeClockRate, ctx.device_id));
        break;
      }
      case kMultiProcessorCount: {
        CUDA_CALL(hipDeviceGetAttribute(
            &value, hipDeviceAttributeMultiprocessorCount, ctx.device_id));
        break;
      }
      case kMaxThreadDimensions: {
        int dims[3];
        CUDA_CALL(hipDeviceGetAttribute(
            &dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id));
        CUDA_CALL(hipDeviceGetAttribute(
            &dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id));
        CUDA_CALL(hipDeviceGetAttribute(
            &dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id));

        std::stringstream ss;  // use a JSON string to return multiple int values
        ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
        *rv = ss.str();
        return;
      }
    }
    *rv = value;
  }
  void* AllocDataSpace(DGLContext ctx,
                       size_t nbytes,
                       size_t alignment,
                       DGLType type_hint) final {
    SetDevice(ctx);
    // Redirect to PyTorch's allocator when available.
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
      return td->CUDAAllocWorkspace(nbytes, getCurrentCUDAStream());

    CHECK_EQ(256 % alignment, 0U)
        << "CUDA space is aligned at 256 bytes";
    void *ret;
    CUDA_CALL(hipMalloc(&ret, nbytes));
    return ret;
  }

  void FreeDataSpace(DGLContext ctx, void* ptr) final {
    SetDevice(ctx);
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
      return td->CUDAFreeWorkspace(ptr);

    CUDA_CALL(hipFree(ptr));
  }

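  // Copy `size` bytes between the source and destination contexts on `stream`.
  // Copies between different GPUs go through hipMemcpyPeerAsync; same-device
  // and host<->device copies use an asynchronous memcpy of the matching kind.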
  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
                      DGLContext ctx_from,
                      DGLContext ctx_to,
                      DGLType type_hint,
                      DGLStreamHandle stream) {
    hipStream_t cu_stream = static_cast<hipStream_t>(stream);
    from = static_cast<const char*>(from) + from_offset;
    to = static_cast<char*>(to) + to_offset;
    if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM) {
      CUDA_CALL(hipSetDevice(ctx_from.device_id));
      if (ctx_from.device_id == ctx_to.device_id) {
        GPUCopy(from, to, size, hipMemcpyDeviceToDevice, cu_stream);
      } else {
        CUDA_CALL(hipMemcpyPeerAsync(to, ctx_to.device_id,
                                      from, ctx_from.device_id,
                                      size, cu_stream));
      }
    } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
      CUDA_CALL(hipSetDevice(ctx_from.device_id));
      GPUCopy(from, to, size, hipMemcpyDeviceToHost, cu_stream);
    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) {
      CUDA_CALL(hipSetDevice(ctx_to.device_id));
      GPUCopy(from, to, size, hipMemcpyHostToDevice, cu_stream);
    } else {
      LOG(FATAL) << "expect copy from/to GPU or between GPUs";
    }
  }

  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
                      DGLContext ctx_from,
                      DGLContext ctx_to,
                      DGLType type_hint) final {
    auto stream = GetStream();
    CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint, stream);
  }

  DGLStreamHandle CreateStream(DGLContext ctx) override {
    CUDA_CALL(hipSetDevice(ctx.device_id));
    hipStream_t retval;
    // make sure the legacy default stream won't block on this stream
    CUDA_CALL(hipStreamCreateWithFlags(&retval, hipStreamNonBlocking));
    return static_cast<DGLStreamHandle>(retval);
  }

  void FreeStream(DGLContext ctx, DGLStreamHandle stream) override {
    CUDA_CALL(hipSetDevice(ctx.device_id));
    hipStream_t cu_stream = static_cast<hipStream_t>(stream);
    CUDA_CALL(hipStreamDestroy(cu_stream));
  }

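  // Make the destination stream wait until all work currently queued on the
  // source stream has completed, using a temporary event.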
  void SyncStreamFromTo(DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) override {
    CUDA_CALL(hipSetDevice(ctx.device_id));
    hipStream_t src_stream = static_cast<hipStream_t>(event_src);
    hipStream_t dst_stream = static_cast<hipStream_t>(event_dst);
    hipEvent_t evt;
    CUDA_CALL(hipEventCreate(&evt));
    CUDA_CALL(hipEventRecord(evt, src_stream));
    CUDA_CALL(hipStreamWaitEvent(dst_stream, evt, 0));
    CUDA_CALL(hipEventDestroy(evt));
  }

  void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {
    CUDA_CALL(hipSetDevice(ctx.device_id));
    CUDA_CALL(hipStreamSynchronize(static_cast<hipStream_t>(stream)));
  }

  /*! NOTE: If the backend is PyTorch, we will use PyTorch's stream management,
   *        so just avoid calling our SetStream/CreateStream unless
   *        you really need advanced stream control.
   * TODO(Xin): Redirect this to PyTorch or remove it.
   * PyTorch allows external CUDA streams to be set as current since v1.11.
   */
  void SetStream(DGLContext ctx, DGLStreamHandle stream) final {}

  DGLStreamHandle GetStream() const final {
    return static_cast<DGLStreamHandle>(getCurrentCUDAStream());
  }

  /*! NOTE: hipHostRegister can be called from an arbitrary GPU device,
   *        so we don't need to specify a ctx.
   *        The pinned memory can be seen by all CUDA contexts,
   *        not just the one that performed the allocation.
   */
  void PinData(void* ptr, size_t nbytes) override {
    // prevent users from pinning empty tensors or graphs
    if (ptr == nullptr || nbytes == 0)
      return;
    CUDA_CALL(hipHostRegister(ptr, nbytes, hipHostRegisterDefault));
  }

  void UnpinData(void* ptr) override {
    if (ptr == nullptr)
      return;
    CUDA_CALL(hipHostUnregister(ptr));
  }

  bool IsPinned(const void* ptr) override {
    // can't be a pinned tensor if CUDA context is unavailable.
    if (!is_available_)
      return false;

    hipPointerAttribute_t attr;
    hipError_t status = hipPointerGetAttributes(&attr, ptr);
    bool result = false;

    switch (status) {
    case hipErrorInvalidValue:
      // might be a normal CPU tensor on CUDA 10.2 and earlier
      hipGetLastError();   // clear error
      break;
    case hipSuccess:
      // result = (attr.type == cudaMemoryTypeHost);
      result = (attr.memoryType == hipMemoryTypeHost);
      break;
    case hipErrorNotInitialized:
    case hipErrorNoDevice:
    case hipErrorInsufficientDriver:
    case hipErrorInvalidDevice:
      // We don't want to fail in these particular cases since this function can be called
      // when users only want to run on CPU even if the CUDA API is enabled, or in a forked
      // subprocess where the CUDA context cannot be initialized.  So we just mark the CUDA
      // context as unavailable and return.
      is_available_ = false;
      hipGetLastError();   // clear error
      break;
    default:
      LOG(FATAL) << "error while determining memory status: " << hipGetErrorString(status);
      break;
    }

    return result;
  }

  void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final {
    SetDevice(ctx);
    // Redirect to PyTorch's allocator when available.
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
      return td->CUDAAllocWorkspace(size, getCurrentCUDAStream());

    return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
  }

  void FreeWorkspace(DGLContext ctx, void* data) final {
    SetDevice(ctx);
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
      return td->CUDAFreeWorkspace(data);

    CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
  }

  static const std::shared_ptr<CUDADeviceAPI>& Global() {
    static std::shared_ptr<CUDADeviceAPI> inst =
        std::make_shared<CUDADeviceAPI>();
    return inst;
  }

 private:
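  // Issue an asynchronous memcpy; device-to-host copies on the legacy default
  // stream are synchronized before returning so the host buffer can be read
  // immediately.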
  static void GPUCopy(const void* from,
                      void* to,
                      size_t size,
                      hipMemcpyKind kind,
                      hipStream_t stream) {
    CUDA_CALL(hipMemcpyAsync(to, from, size, kind, stream));
    if (stream == 0 && kind == hipMemcpyDeviceToHost) {
      // only wait for the copy when it is on the default stream and targets host memory
      CUDA_CALL(hipStreamSynchronize(stream));
    }
  }

  bool is_available_ = true;
};

typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;

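// Each thread owns a CUDAThreadEntry whose workspace pool allocates scratch
// memory on the GPU through this device API.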
CUDAThreadEntry::CUDAThreadEntry()
    : pool(kDLROCM, CUDADeviceAPI::Global()) {
}

CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
  return CUDAThreadStore::Get();
}

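// Return the stream DGL work should be launched on: PyTorch's current stream
// when the TensorDispatcher is available, otherwise the default stream.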
hipStream_t getCurrentCUDAStream() {
  TensorDispatcher* td = TensorDispatcher::Global();
  if (td->IsAvailable())
    return td->CUDAGetCurrentStream();
  else  // return the default stream when the TensorDispatcher is not available
    return nullptr;
}

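// Expose the device API singleton to the DGL runtime under the key
// "device_api.gpu" so it can be fetched through the global registry.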
DGL_REGISTER_GLOBAL("device_api.gpu")
.set_body([](DGLArgs args, DGLRetValue* rv) {
    DeviceAPI* ptr = CUDADeviceAPI::Global().get();
    *rv = static_cast<void*>(ptr);
  });

}  // namespace runtime
}  // namespace dgl