Commit 74d88bf8 authored by sangwz
Browse files

Merge branch 'dtk25.04' of http://developer.sourcefind.cn/codes/OpenDAS/dgl into 2.2.1

parents 2a1ac588 314cedc1
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2023 by Contributors * Copyright (c) 2023 by Contributors
* *
...@@ -10,9 +11,9 @@ ...@@ -10,9 +11,9 @@
#include <unordered_map> #include <unordered_map>
#include "./concurrent_id_hash_map.h" #include "concurrent_id_hash_map.h"
#include "./macro.h" #include "macro.h"
#include "./utils.h" #include "utils.h"
namespace graphbolt { namespace graphbolt {
namespace sampling { namespace sampling {
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/array.h * @file dgl/array.h
...@@ -8,10 +9,10 @@ ...@@ -8,10 +9,10 @@
*/ */
#ifndef DGL_ARRAY_H_ #ifndef DGL_ARRAY_H_
#define DGL_ARRAY_H_ #define DGL_ARRAY_H_
#include "./aten/array_ops.h" #include "aten/array_ops.h"
#include "./aten/coo.h" #include "aten/coo.h"
#include "./aten/csr.h" #include "aten/csr.h"
#include "./aten/macro.h" #include "aten/macro.h"
#include "./aten/spmat.h" #include "aten/spmat.h"
#include "./aten/types.h" #include "aten/types.h"
#endif // DGL_ARRAY_H_ #endif // DGL_ARRAY_H_
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/array_iterator.h * @file dgl/array_iterator.h
...@@ -6,11 +7,11 @@ ...@@ -6,11 +7,11 @@
#ifndef DGL_ARRAY_ITERATOR_H_ #ifndef DGL_ARRAY_ITERATOR_H_
#define DGL_ARRAY_ITERATOR_H_ #define DGL_ARRAY_ITERATOR_H_
#ifdef __CUDA_ARCH__ #ifdef __HIPCC__
#define CUB_INLINE __host__ __device__ __forceinline__ #define CUB_INLINE __host__ __device__ __forceinline__
#else #else
#define CUB_INLINE inline #define CUB_INLINE inline
#endif // __CUDA_ARCH__ #endif // __HIP_DEVICE_COMPILE__
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/array_ops.h * @file dgl/aten/array_ops.h
...@@ -15,7 +16,7 @@ ...@@ -15,7 +16,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "./types.h" #include "types.h"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2020-2022 by Contributors * Copyright (c) 2020-2022 by Contributors
...@@ -15,10 +16,10 @@ ...@@ -15,10 +16,10 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "./array_ops.h" #include "array_ops.h"
#include "./macro.h" #include "macro.h"
#include "./spmat.h" #include "spmat.h"
#include "./types.h" #include "types.h"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2020-2022 by Contributors * Copyright (c) 2020-2022 by Contributors
* @file dgl/aten/csr.h * @file dgl/aten/csr.h
...@@ -14,10 +15,10 @@ ...@@ -14,10 +15,10 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "./array_ops.h" #include "array_ops.h"
#include "./macro.h" #include "macro.h"
#include "./spmat.h" #include "spmat.h"
#include "./types.h" #include "types.h"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/macro.h * @file dgl/aten/macro.h
...@@ -47,7 +48,7 @@ ...@@ -47,7 +48,7 @@
if ((val) == kDGLCPU) { \ if ((val) == kDGLCPU) { \
constexpr auto XPU = kDGLCPU; \ constexpr auto XPU = kDGLCPU; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ((val) == kDGLCUDA) { \ } else if ((val) == kDGLCUDA or (val) == kDGLROCM) { \
constexpr auto XPU = kDGLCUDA; \ constexpr auto XPU = kDGLCUDA; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else { \ } else { \
...@@ -145,12 +146,12 @@ ...@@ -145,12 +146,12 @@
typedef double FloatType; \ typedef double FloatType; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ( \ } else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \ (XPU == kDGLCUDA || XPU == kDGLROCM)&&(val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \ typedef __half FloatType; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ( \ } else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \ (XPU == kDGLCUDA || XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLBfloat) { \
typedef __nv_bfloat16 FloatType; \ typedef __hip_bfloat16 FloatType; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ( \ } else if ( \
XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \ XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
...@@ -176,11 +177,11 @@ ...@@ -176,11 +177,11 @@
typedef double FloatType; \ typedef double FloatType; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ( \ } else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \ (XPU == kDGLCUDA || XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \ typedef __half FloatType; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ( \ } else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \ (XPU == kDGLCUDA || XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLBfloat) { \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \ LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
} else if ( \ } else if ( \
XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \ XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/spmat.h * @file dgl/aten/spmat.h
...@@ -10,7 +11,7 @@ ...@@ -10,7 +11,7 @@
#include <vector> #include <vector>
#include "../runtime/object.h" #include "../runtime/object.h"
#include "./types.h" #include "types.h"
namespace dgl { namespace dgl {
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file dgl/heterograph_interface.h * @file dgl/heterograph_interface.h
...@@ -13,7 +14,7 @@ ...@@ -13,7 +14,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "./runtime/object.h" #include "runtime/object.h"
#include "array.h" #include "array.h"
#include "aten/spmat.h" #include "aten/spmat.h"
#include "aten/types.h" #include "aten/types.h"
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/bcast.h * @file dgl/aten/bcast.h
...@@ -9,7 +10,7 @@ ...@@ -9,7 +10,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "./runtime/ndarray.h" #include "runtime/ndarray.h"
using namespace dgl::runtime; using namespace dgl::runtime;
namespace dgl { namespace dgl {
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* @file dgl/graph_interface.h * @file dgl/graph_interface.h
...@@ -12,7 +13,7 @@ ...@@ -12,7 +13,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "./runtime/object.h" #include "runtime/object.h"
#include "array.h" #include "array.h"
namespace dgl { namespace dgl {
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/kernel.h * @file dgl/aten/kernel.h
...@@ -10,8 +11,8 @@ ...@@ -10,8 +11,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "./base_heterograph.h" #include "base_heterograph.h"
#include "./bcast.h" #include "bcast.h"
#include "array.h" #include "array.h"
namespace dgl { namespace dgl {
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file dgl/nodeflow.h * @file dgl/nodeflow.h
...@@ -10,7 +11,7 @@ ...@@ -10,7 +11,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "./runtime/object.h" #include "runtime/object.h"
#include "graph_interface.h" #include "graph_interface.h"
namespace dgl { namespace dgl {
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file packed_func_ext.h * @file packed_func_ext.h
...@@ -12,9 +13,9 @@ ...@@ -12,9 +13,9 @@
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include "./runtime/container.h" #include "runtime/container.h"
#include "./runtime/object.h" #include "runtime/object.h"
#include "./runtime/packed_func.h" #include "runtime/packed_func.h"
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file dgl/runtime/c_object_api.h * @file dgl/runtime/c_object_api.h
...@@ -10,7 +11,7 @@ ...@@ -10,7 +11,7 @@
#ifndef DGL_RUNTIME_C_OBJECT_API_H_ #ifndef DGL_RUNTIME_C_OBJECT_API_H_
#define DGL_RUNTIME_C_OBJECT_API_H_ #define DGL_RUNTIME_C_OBJECT_API_H_
#include "./c_runtime_api.h" #include "c_runtime_api.h"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
// DGL version // DGL version
#define DGL_VERSION "2.2.1" #define DGL_VERSION "2.2.1"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
...@@ -55,7 +56,8 @@ typedef enum { ...@@ -55,7 +56,8 @@ typedef enum {
/** @brief CPU device */ /** @brief CPU device */
kDGLCPU = 1, kDGLCPU = 1,
/** @brief CUDA GPU device */ /** @brief CUDA GPU device */
kDGLCUDA = 2, kDGLCUDA = 10,
kDGLROCM = 2,
// add more devices once supported // add more devices once supported
} DGLDeviceType; } DGLDeviceType;
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* @file dgl/runtime/device_api.h * @file dgl/runtime/device_api.h
...@@ -174,7 +175,7 @@ class DeviceAPI { ...@@ -174,7 +175,7 @@ class DeviceAPI {
DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst); DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst);
/** /**
* @brief Pin host memory using cudaHostRegister(). * @brief Pin host memory using hipHostRegister().
* *
* @param ptr The host memory pointer to be pinned. * @param ptr The host memory pointer to be pinned.
* @param nbytes The size to be pinned. * @param nbytes The size to be pinned.
...@@ -183,7 +184,7 @@ class DeviceAPI { ...@@ -183,7 +184,7 @@ class DeviceAPI {
DGL_DLL virtual bool PinData(void* ptr, size_t nbytes); DGL_DLL virtual bool PinData(void* ptr, size_t nbytes);
/** /**
* @brief Unpin host memory using cudaHostUnregister(). * @brief Unpin host memory using hipHostUnregister().
* *
* @param ptr The host memory pointer to be unpinned. * @param ptr The host memory pointer to be unpinned.
*/ */
...@@ -203,7 +204,7 @@ class DeviceAPI { ...@@ -203,7 +204,7 @@ class DeviceAPI {
/** /**
* @brief 'Deallocate' the pinned memory from PyTorch CachingHostAllocator. * @brief 'Deallocate' the pinned memory from PyTorch CachingHostAllocator.
* @note It avoids unnecessary cudaFreeHost calls and puts the memory * @note It avoids unnecessary hipHostFree calls and puts the memory
* block into CachingHostAllocator's free list. * block into CachingHostAllocator's free list.
* @param deleter Pointer to the deleter function from PyTorch's * @param deleter Pointer to the deleter function from PyTorch's
* CachingHostAllocator. * CachingHostAllocator.
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* @file dgl/runtime/module.h * @file dgl/runtime/module.h
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2017-2022 by Contributors * Copyright (c) 2017-2022 by Contributors
* @file dgl/runtime/ndarray.h * @file dgl/runtime/ndarray.h
...@@ -18,13 +19,20 @@ ...@@ -18,13 +19,20 @@
#include "shared_mem.h" #include "shared_mem.h"
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
#include <cuda_runtime.h> #include <hip/hip_runtime.h>
#define BF16_ENABLED (defined(CUDART_VERSION) && CUDART_VERSION >= 11000) // #define BF16_ENABLED (defined(DTKRT_VERSION) && DTKRT_VERSION >= 11000)
#if defined(DTKRT_VERSION)
#define DTKRT_VERSION_CHECK (DTKRT_VERSION >= 11000)
#else
#define DTKRT_VERSION_CHECK 0
#endif
#include <cuda_fp16.h> #define BF16_ENABLED DTKRT_VERSION_CHECK
#include <hip/hip_fp16.h>
#if BF16_ENABLED #if BF16_ENABLED
#include <cuda_bf16.h> #include <hip/hip_bf16.h>
#endif // BF16_ENABLED #endif // BF16_ENABLED
#endif // DGL_USE_CUDA #endif // DGL_USE_CUDA
...@@ -60,7 +68,7 @@ GEN_DGLDATATYPETRAITS_FOR(uint64_t, kDGLInt, 64); ...@@ -60,7 +68,7 @@ GEN_DGLDATATYPETRAITS_FOR(uint64_t, kDGLInt, 64);
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
GEN_DGLDATATYPETRAITS_FOR(__half, kDGLFloat, 16); GEN_DGLDATATYPETRAITS_FOR(__half, kDGLFloat, 16);
#if BF16_ENABLED #if BF16_ENABLED
GEN_DGLDATATYPETRAITS_FOR(__nv_bfloat16, kDGLBfloat, 16); GEN_DGLDATATYPETRAITS_FOR(__hip_bfloat16, kDGLBfloat, 16);
#endif // BF16_ENABLED #endif // BF16_ENABLED
#endif // DGL_USE_CUDA #endif // DGL_USE_CUDA
GEN_DGLDATATYPETRAITS_FOR(float, kDGLFloat, 32); GEN_DGLDATATYPETRAITS_FOR(float, kDGLFloat, 32);
...@@ -185,7 +193,7 @@ class NDArray { ...@@ -185,7 +193,7 @@ class NDArray {
* CachingHostAllocator for allocating pinned memory and copying data * CachingHostAllocator for allocating pinned memory and copying data
* from the current NDArray. As a result, PyTorch is responsible for * from the current NDArray. As a result, PyTorch is responsible for
* managing the lifecycle of the returned NDArray, including deciding * managing the lifecycle of the returned NDArray, including deciding
* when to flush the data for reuse or call cudaFreeHost. The current * when to flush the data for reuse or call hipHostFree. The current
* context must be kDGLCPU, otherwise, an error will be thrown. * context must be kDGLCPU, otherwise, an error will be thrown.
*/ */
inline NDArray PinMemory(); inline NDArray PinMemory();
...@@ -194,7 +202,7 @@ class NDArray { ...@@ -194,7 +202,7 @@ class NDArray {
* @brief In-place method to pin the current array by calling PinContainer * @brief In-place method to pin the current array by calling PinContainer
* on the underlying NDArray:Container. * on the underlying NDArray:Container.
* @note This is an in-place method that flags the memory as page-locked by * @note This is an in-place method that flags the memory as page-locked by
* utilizing cudaHostRegister at the underlying level to pin the current * utilizing hipHostRegister at the underlying level to pin the current
* instance of NDArray. The current context must be kDGLCPU, otherwise, * instance of NDArray. The current context must be kDGLCPU, otherwise,
* an error will be thrown. * an error will be thrown.
*/ */
...@@ -523,7 +531,7 @@ inline void NDArray::CopyFrom(const NDArray& other) { ...@@ -523,7 +531,7 @@ inline void NDArray::CopyFrom(const NDArray& other) {
// Pinned by PyTorch // Pinned by PyTorch
if (cpu_data->pinned_by_pytorch_) { if (cpu_data->pinned_by_pytorch_) {
// To ensure correct behavior, the event must be recorded after // To ensure correct behavior, the event must be recorded after
// cudaMemcpyAsync as long as the memory is pinned by PyTorch. // hipMemcpyAsync as long as the memory is pinned by PyTorch.
void* pytorch_ctx = cpu_data->pytorch_ctx_; void* pytorch_ctx = cpu_data->pytorch_ctx_;
RecordedCopyFromTo( RecordedCopyFromTo(
&(other.data_->dl_tensor), &(data_->dl_tensor), pytorch_ctx); &(other.data_->dl_tensor), &(data_->dl_tensor), pytorch_ctx);
...@@ -549,7 +557,7 @@ inline void NDArray::CopyTo(const NDArray& other) const { ...@@ -549,7 +557,7 @@ inline void NDArray::CopyTo(const NDArray& other) const {
// pinned by PyTorch // pinned by PyTorch
if (cpu_data->pinned_by_pytorch_) { if (cpu_data->pinned_by_pytorch_) {
// To ensure correct behavior, the event must be recorded after // To ensure correct behavior, the event must be recorded after
// cudaMemcpyAsync as long as the memory is pinned by PyTorch. // hipMemcpyAsync as long as the memory is pinned by PyTorch.
void* pytorch_ctx = cpu_data->pytorch_ctx_; void* pytorch_ctx = cpu_data->pytorch_ctx_;
RecordedCopyFromTo( RecordedCopyFromTo(
&(data_->dl_tensor), &(other.data_->dl_tensor), pytorch_ctx); &(data_->dl_tensor), &(other.data_->dl_tensor), pytorch_ctx);
...@@ -716,6 +724,8 @@ inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) { ...@@ -716,6 +724,8 @@ inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) {
return "cpu"; return "cpu";
case kDGLCUDA: case kDGLCUDA:
return "cuda"; return "cuda";
case kDGLROCM:
return "cuda";
default: default:
LOG(FATAL) << "Unsupported device type code=" LOG(FATAL) << "Unsupported device type code="
<< static_cast<int>(device_type); << static_cast<int>(device_type);
...@@ -871,8 +881,11 @@ inline std::ostream& operator<<(std::ostream& os, DGLDataType t) { ...@@ -871,8 +881,11 @@ inline std::ostream& operator<<(std::ostream& os, DGLDataType t) {
/** @brief Check whether two device contexts are the same.*/ /** @brief Check whether two device contexts are the same.*/
inline bool operator==(const DGLContext& ctx1, const DGLContext& ctx2) { inline bool operator==(const DGLContext& ctx1, const DGLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && // printf("**************** debug compare DGLContext, %d, %d\n",ctx1.device_type,ctx2.device_type);
ctx1.device_id == ctx2.device_id; int ct1=ctx1.device_type==10?2:ctx1.device_type;
int ct2=ctx2.device_type==10?2:ctx2.device_type;
return ct1 == ct2 &&
int(ctx1.device_id) == int(ctx2.device_id);
} }
/** @brief Check whether two device contexts are different.*/ /** @brief Check whether two device contexts are different.*/
......
// !!! This is a file automatically generated by hipify!!!
/** /**
* Copyright (c) 2020-2022 by Contributors * Copyright (c) 2020-2022 by Contributors
* @file array/tensordispatch.h * @file array/tensordispatch.h
...@@ -34,7 +35,7 @@ ...@@ -34,7 +35,7 @@
#include <windows.h> #include <windows.h>
#endif // WIN32 #endif // WIN32
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
#include <cuda_runtime.h> #include <hip/hip_runtime.h>
#endif // DGL_USE_CUDA #endif // DGL_USE_CUDA
#include "ndarray.h" #include "ndarray.h"
...@@ -97,14 +98,14 @@ class TensorDispatcher { ...@@ -97,14 +98,14 @@ class TensorDispatcher {
* Used in CUDADeviceAPI::AllocWorkspace(). * Used in CUDADeviceAPI::AllocWorkspace().
* *
* @note THCCachingAllocator specifies the device to allocate on * @note THCCachingAllocator specifies the device to allocate on
* via cudaGetDevice(). Make sure to call cudaSetDevice() * via hipGetDevice(). Make sure to call hipSetDevice()
* before invoking this function. * before invoking this function.
* *
* @param nbytes The size to be allocated. * @param nbytes The size to be allocated.
* @param stream The stream to be allocated on. * @param stream The stream to be allocated on.
* @return Pointer to the allocated memory. * @return Pointer to the allocated memory.
*/ */
inline void* CUDAAllocWorkspace(size_t nbytes, cudaStream_t stream) { inline void* CUDAAllocWorkspace(size_t nbytes, hipStream_t stream) {
auto entry = entrypoints_[Op::kCUDARawAlloc]; auto entry = entrypoints_[Op::kCUDARawAlloc];
return FUNCCAST(tensoradapter::CUDARawAlloc, entry)(nbytes, stream); return FUNCCAST(tensoradapter::CUDARawAlloc, entry)(nbytes, stream);
} }
...@@ -122,15 +123,15 @@ class TensorDispatcher { ...@@ -122,15 +123,15 @@ class TensorDispatcher {
/** /**
* @brief Find the current PyTorch CUDA stream * @brief Find the current PyTorch CUDA stream
* Used in runtime::getCurrentCUDAStream(). * Used in runtime::getCurrentHIPStreamMasqueradingAsCUDA().
* *
* @note PyTorch pre-allocates/sets the current CUDA stream * @note PyTorch pre-allocates/sets the current CUDA stream
* on current device via cudaGetDevice(). Make sure to call cudaSetDevice() * on current device via hipGetDevice(). Make sure to call hipSetDevice()
* before invoking this function. * before invoking this function.
* *
* @return cudaStream_t stream handle * @return hipStream_t stream handle
*/ */
inline cudaStream_t CUDAGetCurrentStream() { inline hipStream_t CUDAGetCurrentStream() {
auto entry = entrypoints_[Op::kCUDACurrentStream]; auto entry = entrypoints_[Op::kCUDACurrentStream];
return FUNCCAST(tensoradapter::CUDACurrentStream, entry)(); return FUNCCAST(tensoradapter::CUDACurrentStream, entry)();
} }
...@@ -183,7 +184,7 @@ class TensorDispatcher { ...@@ -183,7 +184,7 @@ class TensorDispatcher {
* @param device_id Device of the tensor. * @param device_id Device of the tensor.
*/ */
inline void CUDARecordHostAlloc( inline void CUDARecordHostAlloc(
void* data, void* ctx, cudaStream_t stream, int device_id) { void* data, void* ctx, hipStream_t stream, int device_id) {
auto entry = entrypoints_[Op::kCUDARecordHostAlloc]; auto entry = entrypoints_[Op::kCUDARecordHostAlloc];
auto recorded_alloc = FUNCCAST(tensoradapter::CUDARecordHostAlloc, entry); auto recorded_alloc = FUNCCAST(tensoradapter::CUDARecordHostAlloc, entry);
recorded_alloc(data, ctx, stream, device_id); recorded_alloc(data, ctx, stream, device_id);
...@@ -212,7 +213,7 @@ class TensorDispatcher { ...@@ -212,7 +213,7 @@ class TensorDispatcher {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto entry = entrypoints_[Op::kRecordStream]; auto entry = entrypoints_[Op::kRecordStream];
FUNCCAST(tensoradapter::RecordStream, entry) FUNCCAST(tensoradapter::RecordStream, entry)
(ptr, static_cast<cudaStream_t>(stream), device_id); (ptr, static_cast<hipStream_t>(stream), device_id);
#endif #endif
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment