/*!
 *  Copyright (c) 2017 by Contributors
 * \file cuda_device_api.cc
 * \brief GPU specific API
 */
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/tensordispatch.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <cuda_runtime.h>
#include "cuda_common.h"

namespace dgl {
namespace runtime {

class CUDADeviceAPI final : public DeviceAPI {
 public:
  CUDADeviceAPI() {
    int count;
    auto err = cudaGetDeviceCount(&count);
    switch (err) {
      case cudaSuccess:
        break;
      default:
        count = 0;
        cudaGetLastError();
    }
    is_available_ = count > 0;
  }

  bool IsAvailable() final {
    return is_available_;
  }

  void SetDevice(DGLContext ctx) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
  }

  void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) final {
    int value = 0;
    switch (kind) {
      case kExist:
        value = (
            cudaDeviceGetAttribute(
                &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
            == cudaSuccess);
        break;
      case kMaxThreadsPerBlock: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
        break;
      }
      case kWarpSize: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrWarpSize, ctx.device_id));
        break;
      }
      case kMaxSharedMemoryPerBlock: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id));
        break;
      }
      case kComputeVersion: {
        std::ostringstream os;
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrComputeCapabilityMajor, ctx.device_id));
        os << value << ".";
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrComputeCapabilityMinor, ctx.device_id));
        os << value;
        *rv = os.str();
        return;
      }
      case kDeviceName: {
        cudaDeviceProp props;
        CUDA_CALL(cudaGetDeviceProperties(&props, ctx.device_id));
        *rv = std::string(props.name);
        return;
      }
      case kMaxClockRate: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrClockRate, ctx.device_id));
        break;
      }
      case kMultiProcessorCount: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMultiProcessorCount, ctx.device_id));
        break;
      }
      case kMaxThreadDimensions: {
        int dims[3];
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));
        std::stringstream ss;  // use json string to return multiple int values
        ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
        *rv = ss.str();
        return;
      }
    }
    *rv = value;
  }

  void* AllocDataSpace(DGLContext ctx,
                       size_t nbytes,
                       size_t alignment,
                       DGLType type_hint) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CHECK_EQ(256 % alignment, 0U)
        << "CUDA space is aligned at 256 bytes";
    void* ret;
    CUDA_CALL(cudaMalloc(&ret, nbytes));
    return ret;
  }

  void FreeDataSpace(DGLContext ctx, void* ptr) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CUDA_CALL(cudaFree(ptr));
  }

  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
                      DGLContext ctx_from,
                      DGLContext ctx_to,
                      DGLType type_hint,
                      DGLStreamHandle stream) final {
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
    from = static_cast<const char*>(from) + from_offset;
    to = static_cast<char*>(to) + to_offset;
    if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
      if (ctx_from.device_id == ctx_to.device_id) {
        GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
      } else {
        CUDA_CALL(cudaMemcpyPeerAsync(to, ctx_to.device_id,
                                      from, ctx_from.device_id,
                                      size, cu_stream));
      }
    } else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) {
      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
      GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLGPU) {
      CUDA_CALL(cudaSetDevice(ctx_to.device_id));
      GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
    } else {
      LOG(FATAL) << "expect copy from/to GPU or between GPU";
    }
  }
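  // Usage sketch (illustrative only; the contexts, buffer size, alignment, and
  // DGLType value below are placeholders, not exercised by this file): a caller
  // could stage a host buffer onto GPU 0 through this API roughly as follows.
  //
  //   DGLContext cpu_ctx{kDLCPU, 0};
  //   DGLContext gpu_ctx{kDLGPU, 0};
  //   std::vector<float> host(1024);
  //   const size_t nbytes = host.size() * sizeof(float);
  //   auto api = CUDADeviceAPI::Global();
  //   void* dev = api->AllocDataSpace(gpu_ctx, nbytes, 256, DGLType{});
  //   api->CopyDataFromTo(host.data(), 0, dev, 0, nbytes,
  //                       cpu_ctx, gpu_ctx, DGLType{}, nullptr);
  //   api->StreamSync(gpu_ctx, nullptr);  // wait on the default stream
  //   api->FreeDataSpace(gpu_ctx, dev);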
  DGLStreamHandle CreateStream(DGLContext ctx) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t retval;
    // make sure the legacy default stream won't block on this stream
    CUDA_CALL(cudaStreamCreateWithFlags(&retval, cudaStreamNonBlocking));
    return static_cast<DGLStreamHandle>(retval);
  }

  void FreeStream(DGLContext ctx, DGLStreamHandle stream) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
    CUDA_CALL(cudaStreamDestroy(cu_stream));
  }

  void SyncStreamFromTo(DGLContext ctx,
                        DGLStreamHandle event_src,
                        DGLStreamHandle event_dst) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
    cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
    cudaEvent_t evt;
    CUDA_CALL(cudaEventCreate(&evt));
    CUDA_CALL(cudaEventRecord(evt, src_stream));
    CUDA_CALL(cudaStreamWaitEvent(dst_stream, evt, 0));
    CUDA_CALL(cudaEventDestroy(evt));
  }

  void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
  }

  void SetStream(DGLContext ctx, DGLStreamHandle stream) final {
    CUDAThreadEntry::ThreadLocal()->stream = static_cast<cudaStream_t>(stream);
  }

  DGLStreamHandle GetStream() const final {
    return static_cast<DGLStreamHandle>(CUDAThreadEntry::ThreadLocal()->stream);
  }

  /*! NOTE: cudaHostRegister can be called from an arbitrary GPU device,
   * so we don't need to specify a ctx.
   * The pinned memory can be seen by all CUDA contexts,
   * not just the one that performed the allocation.
   */
  void PinData(void* ptr, size_t nbytes) {
    // prevent users from pinning empty tensors or graphs
    if (ptr == nullptr || nbytes == 0) return;
    CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
  }

  void UnpinData(void* ptr) {
    if (ptr == nullptr) return;
    CUDA_CALL(cudaHostUnregister(ptr));
  }

  bool IsPinned(const void* ptr) override {
    // can't be a pinned tensor if CUDA context is unavailable.
    if (!is_available_) return false;

    cudaPointerAttributes attr;
    cudaError_t status = cudaPointerGetAttributes(&attr, ptr);
    bool result = false;

    switch (status) {
      case cudaErrorInvalidValue:
        // might be a normal CPU tensor in CUDA 10.2-
        cudaGetLastError();  // clear error
        break;
      case cudaSuccess:
        result = (attr.type == cudaMemoryTypeHost);
        break;
      case cudaErrorInitializationError:
      case cudaErrorNoDevice:
      case cudaErrorInsufficientDriver:
      case cudaErrorInvalidDevice:
        // We don't want to fail in these particular cases since this function
        // can be called when users only want to run on CPU even if CUDA API is
        // enabled, or in a forked subprocess where the CUDA context cannot be
        // initialized. So we just mark the CUDA context as unavailable and
        // return.
        is_available_ = false;
        cudaGetLastError();  // clear error
        break;
      default:
        LOG(FATAL) << "error while determining memory status: "
                   << cudaGetErrorString(status);
        break;
    }

    return result;
  }
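  // Workspace allocation below goes through TensorDispatcher when a tensor
  // framework adaptor (PyTorch's allocator) is available, so that temporary
  // buffers are served from the framework's caching allocator rather than
  // DGL's own WorkspacePool; otherwise the per-thread pool handles the request.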
  void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final {
    SetDevice(ctx);
    // Redirect to PyTorch's allocator when available.
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
      return td->AllocWorkspace(size);
    else
      return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
  }

  void FreeWorkspace(DGLContext ctx, void* data) final {
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
      td->FreeWorkspace(data);
    else
      CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
  }

  static const std::shared_ptr<CUDADeviceAPI>& Global() {
    static std::shared_ptr<CUDADeviceAPI> inst =
        std::make_shared<CUDADeviceAPI>();
    return inst;
  }

 private:
  static void GPUCopy(const void* from,
                      void* to,
                      size_t size,
                      cudaMemcpyKind kind,
                      cudaStream_t stream) {
    CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
    if (stream == 0 && kind == cudaMemcpyDeviceToHost) {
      // only wait for the copy when it's on the default stream and it's to
      // host memory
      CUDA_CALL(cudaStreamSynchronize(stream));
    }
  }

  bool is_available_ = true;
};

typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;

CUDAThreadEntry::CUDAThreadEntry()
    : pool(kDLGPU, CUDADeviceAPI::Global()) {
}

CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
  return CUDAThreadStore::Get();
}

DGL_REGISTER_GLOBAL("device_api.gpu")
.set_body([](DGLArgs args, DGLRetValue* rv) {
    DeviceAPI* ptr = CUDADeviceAPI::Global().get();
    *rv = static_cast<void*>(ptr);
  });

}  // namespace runtime
}  // namespace dgl