/*!
 *  Copyright (c) 2017-2022 by Contributors
 * \file cuda_device_api.cc
 * \brief GPU specific API
 */
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <cuda_runtime.h>

#include <memory>
#include <sstream>
#include <string>

#include "cuda_common.h"

namespace dgl {
namespace runtime {

class CUDADeviceAPI final : public DeviceAPI {
 public:
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
  CUDADeviceAPI() {
    int count;
    auto err = cudaGetDeviceCount(&count);
    switch (err) {
      case cudaSuccess:
        break;
      default:
        count = 0;
        cudaGetLastError();
    }
    is_available_ = count > 0;
  }

  bool IsAvailable() final {
    return is_available_;
  }

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  void SetDevice(DGLContext ctx) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
  }
  void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) final {
    int value = 0;
    switch (kind) {
      case kExist:
        value = (
            cudaDeviceGetAttribute(
                &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
            == cudaSuccess);
        break;
      case kMaxThreadsPerBlock: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
        break;
      }
      case kWarpSize: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrWarpSize, ctx.device_id));
        break;
      }
      case kMaxSharedMemoryPerBlock: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id));
        break;
      }
      case kComputeVersion: {
        std::ostringstream os;
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrComputeCapabilityMajor, ctx.device_id));
        os << value << ".";
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrComputeCapabilityMinor, ctx.device_id));
        os << value;
        *rv = os.str();
        return;
      }
      case kDeviceName: {
        cudaDeviceProp props;
        CUDA_CALL(cudaGetDeviceProperties(&props, ctx.device_id));
        *rv = std::string(props.name);
        return;
      }
      case kMaxClockRate: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrClockRate, ctx.device_id));
        break;
      }
      case kMultiProcessorCount: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMultiProcessorCount, ctx.device_id));
        break;
      }
      case kMaxThreadDimensions: {
        int dims[3];
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));

        std::stringstream ss;  // use json string to return multiple int values;
        ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]";
        *rv = ss.str();
        return;
      }
    }
    *rv = value;
  }
  void* AllocDataSpace(DGLContext ctx,
                       size_t nbytes,
                       size_t alignment,
                       DGLType type_hint) final {
110
111
112
113
114
115
    SetDevice(ctx);
    // Redirect to PyTorch's allocator when available.
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
      return td->CUDAAllocWorkspace(nbytes, CUDAThreadEntry::ThreadLocal()->stream);

116
117
118
119
120
121
122
123
    CHECK_EQ(256 % alignment, 0U)
        << "CUDA space is aligned at 256 bytes";
    void *ret;
    CUDA_CALL(cudaMalloc(&ret, nbytes));
    return ret;
  }

  void FreeDataSpace(DGLContext ctx, void* ptr) final {
124
125
126
127
128
    SetDevice(ctx);
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
      return td->CUDAFreeWorkspace(ptr);

129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
    CUDA_CALL(cudaFree(ptr));
  }

  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
                      DGLContext ctx_from,
                      DGLContext ctx_to,
                      DGLType type_hint,
                      DGLStreamHandle stream) final {
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
    from = static_cast<const char*>(from) + from_offset;
    to = static_cast<char*>(to) + to_offset;
    if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
      if (ctx_from.device_id == ctx_to.device_id) {
        GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
      } else {
149
150
151
        CUDA_CALL(cudaMemcpyPeerAsync(to, ctx_to.device_id,
                                      from, ctx_from.device_id,
                                      size, cu_stream));
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
      }
    } else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) {
      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
      GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLGPU) {
      CUDA_CALL(cudaSetDevice(ctx_to.device_id));
      GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
    } else {
      LOG(FATAL) << "expect copy from/to GPU or between GPU";
    }
  }

  DGLStreamHandle CreateStream(DGLContext ctx) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t retval;
167
168
    // make sure the legacy default stream won't block on this stream
    CUDA_CALL(cudaStreamCreateWithFlags(&retval, cudaStreamNonBlocking));
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
    return static_cast<DGLStreamHandle>(retval);
  }

  void FreeStream(DGLContext ctx, DGLStreamHandle stream) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
    CUDA_CALL(cudaStreamDestroy(cu_stream));
  }

  void SyncStreamFromTo(DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
    cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
    cudaEvent_t evt;
    CUDA_CALL(cudaEventCreate(&evt));
    CUDA_CALL(cudaEventRecord(evt, src_stream));
    CUDA_CALL(cudaStreamWaitEvent(dst_stream, evt, 0));
    CUDA_CALL(cudaEventDestroy(evt));
  }

  void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
  }

  void SetStream(DGLContext ctx, DGLStreamHandle stream) final {
    CUDAThreadEntry::ThreadLocal()
        ->stream = static_cast<cudaStream_t>(stream);
  }

199
200
201
202
  DGLStreamHandle GetStream() const final {
    return static_cast<DGLStreamHandle>(CUDAThreadEntry::ThreadLocal()->stream);
  }

203
204
205
206
207
208
  /*! NOTE: cudaHostRegister can be called from an arbitrary GPU device,
   *        so we don't need to specify a ctx.
   *        The pinned memory can be seen by all CUDA contexts,
   *        not just the one that performed the allocation
   */
  void PinData(void* ptr, size_t nbytes) {
209
210
211
    // prevent users from pinning empty tensors or graphs
    if (ptr == nullptr || nbytes == 0)
      return;
212
213
214
    CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
  }

215
  void UnpinData(void* ptr) {
216
217
    if (ptr == nullptr)
      return;
218
219
220
    CUDA_CALL(cudaHostUnregister(ptr));
  }

221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
  bool IsPinned(const void* ptr) override {
    // can't be a pinned tensor if CUDA context is unavailable.
    if (!is_available_)
      return false;

    cudaPointerAttributes attr;
    cudaError_t status = cudaPointerGetAttributes(&attr, ptr);
    bool result = false;

    switch (status) {
    case cudaErrorInvalidValue:
      // might be a normal CPU tensor in CUDA 10.2-
      cudaGetLastError();   // clear error
      break;
    case cudaSuccess:
      result = (attr.type == cudaMemoryTypeHost);
      break;
    case cudaErrorInitializationError:
    case cudaErrorNoDevice:
    case cudaErrorInsufficientDriver:
    case cudaErrorInvalidDevice:
      // We don't want to fail in these particular cases since this function can be called
      // when users only want to run on CPU even if CUDA API is enabled, or in a forked
      // subprocess where CUDA context cannot be initialized.  So we just mark the CUDA
      // context to unavailable and return.
      is_available_ = false;
      cudaGetLastError();   // clear error
      break;
    default:
      LOG(FATAL) << "error while determining memory status: " << cudaGetErrorString(status);
      break;
    }

    return result;
  }

257
  void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final {
258
    SetDevice(ctx);
259
    // Redirect to PyTorch's allocator when available.
260
261
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
262
263
264
      return td->CUDAAllocWorkspace(size, CUDAThreadEntry::ThreadLocal()->stream);

    return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
265
266
267
  }

  void FreeWorkspace(DGLContext ctx, void* data) final {
268
    SetDevice(ctx);
269
270
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
271
272
273
      return td->CUDAFreeWorkspace(data);

    CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
274
275
276
277
278
279
280
281
282
283
284
285
286
287
  }

  static const std::shared_ptr<CUDADeviceAPI>& Global() {
    static std::shared_ptr<CUDADeviceAPI> inst =
        std::make_shared<CUDADeviceAPI>();
    return inst;
  }

 private:
  static void GPUCopy(const void* from,
                      void* to,
                      size_t size,
                      cudaMemcpyKind kind,
                      cudaStream_t stream) {
288
289
290
291
    CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
    if (stream == 0 && kind == cudaMemcpyDeviceToHost) {
      // only wait for the copy, when it's on the default stream, and it's to host memory
      CUDA_CALL(cudaStreamSynchronize(stream));
292
293
    }
  }
294
295

  bool is_available_ = true;
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
};

/*! \brief dmlc::ThreadLocalStore instantiation that holds the CUDAThreadEntry objects. */
using CUDAThreadStore = dmlc::ThreadLocalStore<CUDAThreadEntry>;

/*! \brief Build the entry's workspace pool from kDLGPU and the shared CUDADeviceAPI instance. */
CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {}

/*! \brief Return the CUDAThreadEntry handed out by CUDAThreadStore (a dmlc::ThreadLocalStore). */
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
  CUDAThreadEntry* const entry = CUDAThreadStore::Get();
  return entry;
}

// Register "device_api.gpu": the body stores the shared CUDADeviceAPI instance
// in the return value as an opaque void* to its DeviceAPI base.
DGL_REGISTER_GLOBAL("device_api.gpu")
.set_body([](DGLArgs args, DGLRetValue* rv) {
    DeviceAPI* const api = CUDADeviceAPI::Global().get();
    *rv = static_cast<void*>(api);
  });

}  // namespace runtime
}  // namespace dgl