Commit aaaecbc9 authored by lisj

Change kDLGPU to kDLROCM (handle the DLPack CUDA device type as ROCm)

parent c454d419
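For context on why this rename is not cosmetic: in the DLPack device enum that DGL builds on, CUDA and ROCm devices are distinct values, so a HIP/ROCm build must tag and check its tensors with the ROCm type. A minimal sketch of the relevant enum values, assuming the DLPack v0.4-era names used here (kDLGPU was later renamed kDLCUDA upstream):

```cpp
// Sketch of the DLPack DLDeviceType values involved in this commit; not the
// full header. kDLGPU denotes a CUDA device, kDLROCM an AMD ROCm/HIP device.
typedef enum {
  kDLCPU = 1,
  kDLGPU = 2,        // CUDA device (renamed kDLCUDA in later DLPack releases)
  kDLCPUPinned = 3,  // CUDA-pinned host memory
  kDLOpenCL = 4,
  kDLROCM = 10,      // ROCm/HIP device -- the value this commit switches to
} DLDeviceType;
```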
@@ -219,7 +219,7 @@ std::pair<IdArray, IdArray> RandomWalkUniform(
 dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
 const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
 ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, FloatType, "random walk GPU kernel", {
-CHECK(restart_prob->ctx.device_type == kDLGPU) << "restart prob should be in GPU.";
+CHECK(restart_prob->ctx.device_type == kDLROCM) << "restart prob should be in GPU.";
 CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
 const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
 const int64_t restart_prob_size = restart_prob->shape[0];
@@ -350,7 +350,7 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
 dim3 block(256);
 dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
 const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
-CHECK(restart_prob->ctx.device_type == kDLGPU) << "restart prob should be in GPU.";
+CHECK(restart_prob->ctx.device_type == kDLROCM) << "restart prob should be in GPU.";
 CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
 const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
 const int64_t restart_prob_size = restart_prob->shape[0];
@@ -480,7 +480,7 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
 const IdArray dst,
 const int64_t num_samples_per_node,
 const int64_t k) {
-CHECK(src->ctx.device_type == kDLGPU) <<
+CHECK(src->ctx.device_type == kDLROCM) <<
 "IdArray needs be on GPU!";
 const IdxType* src_data = src.Ptr<IdxType>();
 const IdxType* dst_data = dst.Ptr<IdxType>();
@@ -496,27 +496,27 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
 }
 template
-std::pair<IdArray, IdArray> RandomWalk<kDLGPU, int32_t>(
+std::pair<IdArray, IdArray> RandomWalk<kDLROCM, int32_t>(
 const HeteroGraphPtr hg,
 const IdArray seeds,
 const TypeArray metapath,
 const std::vector<FloatArray> &prob);
 template
-std::pair<IdArray, IdArray> RandomWalk<kDLGPU, int64_t>(
+std::pair<IdArray, IdArray> RandomWalk<kDLROCM, int64_t>(
 const HeteroGraphPtr hg,
 const IdArray seeds,
 const TypeArray metapath,
 const std::vector<FloatArray> &prob);
 template
-std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLGPU, int32_t>(
+std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLROCM, int32_t>(
 const HeteroGraphPtr hg,
 const IdArray seeds,
 const TypeArray metapath,
 const std::vector<FloatArray> &prob,
 double restart_prob);
 template
-std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLGPU, int64_t>(
+std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLROCM, int64_t>(
 const HeteroGraphPtr hg,
 const IdArray seeds,
 const TypeArray metapath,
@@ -524,14 +524,14 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLGPU, int64_t>(
 double restart_prob);
 template
-std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLGPU, int32_t>(
+std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLROCM, int32_t>(
 const HeteroGraphPtr hg,
 const IdArray seeds,
 const TypeArray metapath,
 const std::vector<FloatArray> &prob,
 FloatArray restart_prob);
 template
-std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLGPU, int64_t>(
+std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLROCM, int64_t>(
 const HeteroGraphPtr hg,
 const IdArray seeds,
 const TypeArray metapath,
@@ -539,13 +539,13 @@ std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLGPU, int64_t>(
 FloatArray restart_prob);
 template
-std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLGPU, int32_t>(
+std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLROCM, int32_t>(
 const IdArray src,
 const IdArray dst,
 const int64_t num_samples_per_node,
 const int64_t k);
 template
-std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLGPU, int64_t>(
+std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLROCM, int64_t>(
 const IdArray src,
 const IdArray dst,
 const int64_t num_samples_per_node,
......
@@ -36,7 +36,7 @@ void CheckRandomWalkInputs(
 // CHECK_SAME_CONTEXT(seeds, metapath);
 if (hg->IsPinned()) {
-CHECK_EQ(seeds->ctx.device_type, kDLGPU) << "Expected seeds (" << seeds->ctx << ")" \
+CHECK_EQ(seeds->ctx.device_type, kDLROCM) << "Expected seeds (" << seeds->ctx << ")" \
 << " to be on the GPU when the graph is pinned.";
 } else if (hg->Context() != seeds->ctx) {
 LOG(FATAL) << "Expected seeds (" << seeds->ctx << ")" << " to have the same " \
......
@@ -70,7 +70,7 @@ void BuildNodeMaps(
 for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
 const IdArray& nodes = input_nodes[ntype];
 if (nodes->shape[0] > 0) {
-CHECK_EQ(nodes->ctx.device_type, kDLGPU);
+CHECK_EQ(nodes->ctx.device_type, kDLROCM);
 node_maps->LhsHashTable(ntype).FillWithDuplicates(
 nodes.Ptr<IdType>(),
 nodes->shape[0],
@@ -92,7 +92,7 @@ CompactGraphsGPU(
 auto device = runtime::DeviceAPI::Get(ctx);
 hipStream_t stream = runtime::getCurrentCUDAStream();
-CHECK_EQ(ctx.device_type, kDLGPU);
+CHECK_EQ(ctx.device_type, kDLROCM);
 // Step 1: Collect the nodes that has connections for each type.
 const uint64_t num_ntypes = graphs[0]->NumVertexTypes();
@@ -255,7 +255,7 @@ CompactGraphsGPU(
 template<>
 std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
-CompactGraphs<kDLGPU, int32_t>(
+CompactGraphs<kDLROCM, int32_t>(
 const std::vector<HeteroGraphPtr> &graphs,
 const std::vector<IdArray> &always_preserve) {
 return CompactGraphsGPU<int32_t>(graphs, always_preserve);
@@ -263,7 +263,7 @@ CompactGraphs<kDLGPU, int32_t>(
 template<>
 std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
-CompactGraphs<kDLGPU, int64_t>(
+CompactGraphs<kDLROCM, int64_t>(
 const std::vector<HeteroGraphPtr> &graphs,
 const std::vector<IdArray> &always_preserve) {
 return CompactGraphsGPU<int64_t>(graphs, always_preserve);
......
@@ -82,7 +82,7 @@ class DeviceNodeMapMaker {
 for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) {
 const IdArray& nodes = lhs_nodes[ntype];
 if (nodes->shape[0] > 0) {
-CHECK_EQ(nodes->ctx.device_type, kDLGPU);
+CHECK_EQ(nodes->ctx.device_type, kDLROCM);
 node_maps->LhsHashTable(ntype).FillWithDuplicates(
 nodes.Ptr<IdType>(),
 nodes->shape[0],
@@ -127,7 +127,7 @@ class DeviceNodeMapMaker {
 for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) {
 const IdArray& nodes = lhs_nodes[ntype];
 if (nodes->shape[0] > 0) {
-CHECK_EQ(nodes->ctx.device_type, kDLGPU);
+CHECK_EQ(nodes->ctx.device_type, kDLROCM);
 node_maps->LhsHashTable(ntype).FillWithUnique(
 nodes.Ptr<IdType>(),
 nodes->shape[0],
@@ -154,7 +154,7 @@ class DeviceNodeMapMaker {
 // Since partial specialization is not allowed for functions, use this as an
-// intermediate for ToBlock where XPU = kDLGPU.
+// intermediate for ToBlock where XPU = kDLROCM.
 template<typename IdType>
 std::tuple<HeteroGraphPtr, std::vector<IdArray>>
 ToBlockGPU(
@@ -170,7 +170,7 @@ ToBlockGPU(
 auto device = runtime::DeviceAPI::Get(ctx);
 hipStream_t stream = runtime::getCurrentCUDAStream();
-CHECK_EQ(ctx.device_type, kDLGPU);
+CHECK_EQ(ctx.device_type, kDLROCM);
 for (const auto& nodes : rhs_nodes) {
 CHECK_EQ(ctx.device_type, nodes->ctx.device_type);
 }
@@ -383,7 +383,7 @@ ToBlockGPU(
 // functions are the same.
 // Using template<> fails to export the symbols.
 std::tuple<HeteroGraphPtr, std::vector<IdArray>>
-// ToBlock<kDLGPU, int32_t>
+// ToBlock<kDLROCM, int32_t>
 ToBlockGPU32(
 HeteroGraphPtr graph,
 const std::vector<IdArray> &rhs_nodes,
@@ -393,7 +393,7 @@ ToBlockGPU32(
 }
 std::tuple<HeteroGraphPtr, std::vector<IdArray>>
-// ToBlock<kDLGPU, int64_t>
+// ToBlock<kDLROCM, int64_t>
 ToBlockGPU64(
 HeteroGraphPtr graph,
 const std::vector<IdArray> &rhs_nodes,
......
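The comment in the hunk above ("partial specialization is not allowed for functions") is why the ToBlockGPU intermediate exists. A hedged sketch of that C++ pattern with simplified, hypothetical signatures (not DGL's real ones): the public entry point is a function template over (device, index type), C++ only permits *full* specializations of it, so each concrete combination forwards to a plain function template parameterized on the index type alone.

```cpp
#include <cstdint>

constexpr int kDLROCM = 10;  // stand-in for the DLPack enum value

// Device-specific worker: an ordinary function template over the index type.
template <typename IdType>
int ToBlockGPU(int graph) { return graph; }

// Public entry point, dispatched on device type at compile time.
template <int XPU, typename IdType>
int ToBlock(int graph);

// A partial specialization such as ToBlock<kDLROCM, IdType> would be
// ill-formed, so every (device, index-type) pair is fully specialized:
template <>
int ToBlock<kDLROCM, std::int32_t>(int graph) {
  return ToBlockGPU<std::int32_t>(graph);
}
template <>
int ToBlock<kDLROCM, std::int64_t>(int graph) {
  return ToBlockGPU<std::int64_t>(graph);
}
```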
@@ -923,36 +923,36 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
 device->FreeWorkspace(ctx, sum_temp_storage);
 }
-template void KNN<kDLGPU, float, int32_t>(
+template void KNN<kDLROCM, float, int32_t>(
 const NDArray& data_points, const IdArray& data_offsets,
 const NDArray& query_points, const IdArray& query_offsets,
 const int k, IdArray result, const std::string& algorithm);
-template void KNN<kDLGPU, float, int64_t>(
+template void KNN<kDLROCM, float, int64_t>(
 const NDArray& data_points, const IdArray& data_offsets,
 const NDArray& query_points, const IdArray& query_offsets,
 const int k, IdArray result, const std::string& algorithm);
-template void KNN<kDLGPU, double, int32_t>(
+template void KNN<kDLROCM, double, int32_t>(
 const NDArray& data_points, const IdArray& data_offsets,
 const NDArray& query_points, const IdArray& query_offsets,
 const int k, IdArray result, const std::string& algorithm);
-template void KNN<kDLGPU, double, int64_t>(
+template void KNN<kDLROCM, double, int64_t>(
 const NDArray& data_points, const IdArray& data_offsets,
 const NDArray& query_points, const IdArray& query_offsets,
 const int k, IdArray result, const std::string& algorithm);
-template void NNDescent<kDLGPU, float, int32_t>(
+template void NNDescent<kDLROCM, float, int32_t>(
 const NDArray& points, const IdArray& offsets,
 IdArray result, const int k, const int num_iters,
 const int num_candidates, const double delta);
-template void NNDescent<kDLGPU, float, int64_t>(
+template void NNDescent<kDLROCM, float, int64_t>(
 const NDArray& points, const IdArray& offsets,
 IdArray result, const int k, const int num_iters,
 const int num_candidates, const double delta);
-template void NNDescent<kDLGPU, double, int32_t>(
+template void NNDescent<kDLROCM, double, int32_t>(
 const NDArray& points, const IdArray& offsets,
 IdArray result, const int k, const int num_iters,
 const int num_candidates, const double delta);
-template void NNDescent<kDLGPU, double, int64_t>(
+template void NNDescent<kDLROCM, double, int64_t>(
 const NDArray& points, const IdArray& offsets,
 IdArray result, const int k, const int num_iters,
 const int num_candidates, const double delta);
......
@@ -172,7 +172,7 @@ ToBlockGPU64(HeteroGraphPtr, const std::vector<IdArray>&, bool, std::vector<IdAr
 template<>
 std::tuple<HeteroGraphPtr, std::vector<IdArray>>
-ToBlock<kDLGPU, int32_t>(HeteroGraphPtr graph,
+ToBlock<kDLROCM, int32_t>(HeteroGraphPtr graph,
 const std::vector<IdArray> &rhs_nodes,
 bool include_rhs_in_lhs,
 std::vector<IdArray>* const lhs_nodes) {
@@ -181,7 +181,7 @@ ToBlock<kDLGPU, int32_t>(HeteroGraphPtr graph,
 template<>
 std::tuple<HeteroGraphPtr, std::vector<IdArray>>
-ToBlock<kDLGPU, int64_t>(HeteroGraphPtr graph,
+ToBlock<kDLROCM, int64_t>(HeteroGraphPtr graph,
 const std::vector<IdArray> &rhs_nodes,
 bool include_rhs_in_lhs,
 std::vector<IdArray>* const lhs_nodes) {
......
@@ -214,7 +214,7 @@ class UnitGraph : public BaseHeteroGraph {
 * \note The graph will be pinned inplace. Behavior depends on the current context,
 * kDLCPU: will be pinned;
 * IsPinned: directly return;
-* kDLGPU: invalid, will throw an error.
+* kDLROCM: invalid, will throw an error.
 * The context check is deferred to pinning the NDArray.
 */
 void PinMemory_() override;
......
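The hunk above updates the documented pinning contract for ROCm builds. A rough sketch of that contract as a guard, using a hypothetical helper (DGL itself defers the check to pinning the underlying NDArray):

```cpp
// Hypothetical illustration of the documented PinMemory_() behavior:
// already-pinned graphs return immediately, CPU graphs get pinned in place,
// and device-resident (kDLROCM) graphs are an error.
void PinInPlaceChecked(UnitGraph* g) {
  if (g->IsPinned()) return;                  // IsPinned: directly return
  CHECK_EQ(g->Context().device_type, kDLCPU)  // kDLROCM: throws here
      << "only a CPU-resident graph can be pinned in place";
  g->PinMemory_();                            // kDLCPU: will be pinned
}
```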
@@ -377,12 +377,12 @@ GeneratePermutationFromRemainder(
 template std::pair<IdArray, IdArray>
-GeneratePermutationFromRemainder<kDLGPU, int32_t>(
+GeneratePermutationFromRemainder<kDLROCM, int32_t>(
 int64_t array_size,
 int num_parts,
 IdArray in_idx);
 template std::pair<IdArray, IdArray>
-GeneratePermutationFromRemainder<kDLGPU, int64_t>(
+GeneratePermutationFromRemainder<kDLROCM, int64_t>(
 int64_t array_size,
 int num_parts,
 IdArray in_idx);
@@ -421,11 +421,11 @@ IdArray MapToLocalFromRemainder(
 }
 template IdArray
-MapToLocalFromRemainder<kDLGPU, int32_t>(
+MapToLocalFromRemainder<kDLROCM, int32_t>(
 int num_parts,
 IdArray in_idx);
 template IdArray
-MapToLocalFromRemainder<kDLGPU, int64_t>(
+MapToLocalFromRemainder<kDLROCM, int64_t>(
 int num_parts,
 IdArray in_idx);
@@ -469,12 +469,12 @@ IdArray MapToGlobalFromRemainder(
 }
 template IdArray
-MapToGlobalFromRemainder<kDLGPU, int32_t>(
+MapToGlobalFromRemainder<kDLROCM, int32_t>(
 int num_parts,
 IdArray in_idx,
 int part_id);
 template IdArray
-MapToGlobalFromRemainder<kDLGPU, int64_t>(
+MapToGlobalFromRemainder<kDLROCM, int64_t>(
 int num_parts,
 IdArray in_idx,
 int part_id);
@@ -599,25 +599,25 @@ GeneratePermutationFromRange(
 template std::pair<IdArray, IdArray>
-GeneratePermutationFromRange<kDLGPU, int32_t, int32_t>(
+GeneratePermutationFromRange<kDLROCM, int32_t, int32_t>(
 int64_t array_size,
 int num_parts,
 IdArray range,
 IdArray in_idx);
 template std::pair<IdArray, IdArray>
-GeneratePermutationFromRange<kDLGPU, int64_t, int32_t>(
+GeneratePermutationFromRange<kDLROCM, int64_t, int32_t>(
 int64_t array_size,
 int num_parts,
 IdArray range,
 IdArray in_idx);
 template std::pair<IdArray, IdArray>
-GeneratePermutationFromRange<kDLGPU, int32_t, int64_t>(
+GeneratePermutationFromRange<kDLROCM, int32_t, int64_t>(
 int64_t array_size,
 int num_parts,
 IdArray range,
 IdArray in_idx);
 template std::pair<IdArray, IdArray>
-GeneratePermutationFromRange<kDLGPU, int64_t, int64_t>(
+GeneratePermutationFromRange<kDLROCM, int64_t, int64_t>(
 int64_t array_size,
 int num_parts,
 IdArray range,
@@ -658,22 +658,22 @@ IdArray MapToLocalFromRange(
 }
 template IdArray
-MapToLocalFromRange<kDLGPU, int32_t, int32_t>(
+MapToLocalFromRange<kDLROCM, int32_t, int32_t>(
 int num_parts,
 IdArray range,
 IdArray in_idx);
 template IdArray
-MapToLocalFromRange<kDLGPU, int64_t, int32_t>(
+MapToLocalFromRange<kDLROCM, int64_t, int32_t>(
 int num_parts,
 IdArray range,
 IdArray in_idx);
 template IdArray
-MapToLocalFromRange<kDLGPU, int32_t, int64_t>(
+MapToLocalFromRange<kDLROCM, int32_t, int64_t>(
 int num_parts,
 IdArray range,
 IdArray in_idx);
 template IdArray
-MapToLocalFromRange<kDLGPU, int64_t, int64_t>(
+MapToLocalFromRange<kDLROCM, int64_t, int64_t>(
 int num_parts,
 IdArray range,
 IdArray in_idx);
@@ -721,25 +721,25 @@ IdArray MapToGlobalFromRange(
 }
 template IdArray
-MapToGlobalFromRange<kDLGPU, int32_t, int32_t>(
+MapToGlobalFromRange<kDLROCM, int32_t, int32_t>(
 int num_parts,
 IdArray range,
 IdArray in_idx,
 int part_id);
 template IdArray
-MapToGlobalFromRange<kDLGPU, int64_t, int32_t>(
+MapToGlobalFromRange<kDLROCM, int64_t, int32_t>(
 int num_parts,
 IdArray range,
 IdArray in_idx,
 int part_id);
 template IdArray
-MapToGlobalFromRange<kDLGPU, int32_t, int64_t>(
+MapToGlobalFromRange<kDLROCM, int32_t, int64_t>(
 int num_parts,
 IdArray range,
 IdArray in_idx,
 int part_id);
 template IdArray
-MapToGlobalFromRange<kDLGPU, int64_t, int64_t>(
+MapToGlobalFromRange<kDLROCM, int64_t, int64_t>(
 int num_parts,
 IdArray range,
 IdArray in_idx,
......
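The instantiations above cover the GPU kernels behind remainder partitioning. For orientation, a hedged scalar sketch of what these ops appear to compute (hypothetical helpers; the real kernels apply the same arithmetic element-wise on the device over whole IdArrays): an element is owned by the partition given by its id modulo the number of parts, and the local/global maps invert each other.

```cpp
#include <cstdint>

// Scalar model of the remainder-partition maps (illustrative sketch only).
std::int64_t PartOf(std::int64_t global_id, int num_parts) {
  return global_id % num_parts;           // owning partition
}
std::int64_t MapToLocal(std::int64_t global_id, int num_parts) {
  return global_id / num_parts;           // index within the partition
}
std::int64_t MapToGlobal(std::int64_t local_id, int part_id, int num_parts) {
  return local_id * num_parts + part_id;  // inverse of the two maps above
}
```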
@@ -46,9 +46,9 @@ class RemainderPartition : public NDArrayPartition {
 IdArray in_idx) const override {
 #ifdef DGL_USE_CUDA
 auto ctx = in_idx->ctx;
-if (ctx.device_type == kDLGPU) {
+if (ctx.device_type == kDLROCM) {
 ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
-return impl::GeneratePermutationFromRemainder<kDLGPU, IdType>(
+return impl::GeneratePermutationFromRemainder<kDLROCM, IdType>(
 ArraySize(), NumParts(), in_idx);
 });
 }
@@ -64,9 +64,9 @@ class RemainderPartition : public NDArrayPartition {
 IdArray in_idx) const override {
 #ifdef DGL_USE_CUDA
 auto ctx = in_idx->ctx;
-if (ctx.device_type == kDLGPU) {
+if (ctx.device_type == kDLROCM) {
 ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
-return impl::MapToLocalFromRemainder<kDLGPU, IdType>(
+return impl::MapToLocalFromRemainder<kDLROCM, IdType>(
 NumParts(), in_idx);
 });
 }
@@ -83,9 +83,9 @@ class RemainderPartition : public NDArrayPartition {
 const int part_id) const override {
 #ifdef DGL_USE_CUDA
 auto ctx = in_idx->ctx;
-if (ctx.device_type == kDLGPU) {
+if (ctx.device_type == kDLROCM) {
 ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
-return impl::MapToGlobalFromRemainder<kDLGPU, IdType>(
+return impl::MapToGlobalFromRemainder<kDLROCM, IdType>(
 NumParts(), in_idx, part_id);
 });
 }
@@ -118,7 +118,7 @@ class RangePartition : public NDArrayPartition {
 // have only one CPU context, and can safely copy the array to that.
 range_cpu_(range.CopyTo(DGLContext{kDLCPU, 0})) {
 auto ctx = range->ctx;
-if (ctx.device_type != kDLGPU) {
+if (ctx.device_type != kDLROCM) {
 LOG(FATAL) << "The range for an NDArrayPartition is only supported "
 " on GPUs. Transfer the range to the target device before "
 "creating the partition.";
@@ -130,7 +130,7 @@ class RangePartition : public NDArrayPartition {
 IdArray in_idx) const override {
 #ifdef DGL_USE_CUDA
 auto ctx = in_idx->ctx;
-if (ctx.device_type == kDLGPU) {
+if (ctx.device_type == kDLROCM) {
 if (ctx.device_type != range_->ctx.device_type ||
 ctx.device_id != range_->ctx.device_id) {
 LOG(FATAL) << "The range for the NDArrayPartition and the input "
@@ -138,7 +138,7 @@ class RangePartition : public NDArrayPartition {
 }
 ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
 ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
-return impl::GeneratePermutationFromRange<kDLGPU, IdType, RangeType>(
+return impl::GeneratePermutationFromRange<kDLROCM, IdType, RangeType>(
 ArraySize(), NumParts(), range_, in_idx);
 });
 });
@@ -155,10 +155,10 @@ class RangePartition : public NDArrayPartition {
 IdArray in_idx) const override {
 #ifdef DGL_USE_CUDA
 auto ctx = in_idx->ctx;
-if (ctx.device_type == kDLGPU) {
+if (ctx.device_type == kDLROCM) {
 ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
 ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
-return impl::MapToLocalFromRange<kDLGPU, IdType, RangeType>(
+return impl::MapToLocalFromRange<kDLROCM, IdType, RangeType>(
 NumParts(), range_, in_idx);
 });
 });
@@ -176,10 +176,10 @@ class RangePartition : public NDArrayPartition {
 const int part_id) const override {
 #ifdef DGL_USE_CUDA
 auto ctx = in_idx->ctx;
-if (ctx.device_type == kDLGPU) {
+if (ctx.device_type == kDLROCM) {
 ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
 ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
-return impl::MapToGlobalFromRange<kDLGPU, IdType, RangeType>(
+return impl::MapToGlobalFromRange<kDLROCM, IdType, RangeType>(
 NumParts(), range_, in_idx, part_id);
 });
 });
......
@@ -29,7 +29,7 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
 }
 });
 #ifdef DGL_USE_CUDA
-if (DeviceAPI::Get(kDLGPU)->IsAvailable()) {
+if (DeviceAPI::Get(kDLROCM)->IsAvailable()) {
 auto* thr_entry = CUDAThreadEntry::ThreadLocal();
 if (!thr_entry->curand_gen) {
 CURAND_CALL(hiprandCreateGenerator(&thr_entry->curand_gen, HIPRAND_RNG_PSEUDO_DEFAULT));
......
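The hunk above guards device-RNG setup behind a runtime availability check; hiprandCreateGenerator and HIPRAND_RNG_PSEUDO_DEFAULT are the hipRAND counterparts of cuRAND's curandCreateGenerator and CURAND_RNG_PSEUDO_DEFAULT. A minimal standalone sketch of creating and seeding such a generator (assumes a ROCm toolchain; error handling simplified, not DGL code, which caches the generator in a thread-local CUDAThreadEntry instead):

```cpp
#include <hiprand/hiprand.h>

// Create a default pseudo-random generator, seed it, and tear it down.
bool SeedDeviceRNG(unsigned long long seed) {
  hiprandGenerator_t gen;
  if (hiprandCreateGenerator(&gen, HIPRAND_RNG_PSEUDO_DEFAULT) !=
      HIPRAND_STATUS_SUCCESS) {
    return false;
  }
  const bool ok =
      hiprandSetPseudoRandomGeneratorSeed(gen, seed) == HIPRAND_STATUS_SUCCESS;
  hiprandDestroyGenerator(gen);
  return ok;
}
```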
@@ -27,7 +27,7 @@ namespace runtime {
 inline std::string DeviceName(int type) {
 switch (type) {
 case kDLCPU: return "cpu";
-case kDLGPU: return "gpu";
+case kDLROCM: return "gpu";
 case kDLOpenCL: return "opencl";
 case kDLSDAccel: return "sdaccel";
 case kDLAOCL: return "aocl";
......
@@ -141,7 +141,7 @@ class CUDADeviceAPI final : public DeviceAPI {
 hipStream_t cu_stream = static_cast<hipStream_t>(stream);
 from = static_cast<const char*>(from) + from_offset;
 to = static_cast<char*>(to) + to_offset;
-if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
+if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM) {
 CUDA_CALL(hipSetDevice(ctx_from.device_id));
 if (ctx_from.device_id == ctx_to.device_id) {
 GPUCopy(from, to, size, hipMemcpyDeviceToDevice, cu_stream);
@@ -150,10 +150,10 @@ class CUDADeviceAPI final : public DeviceAPI {
 from, ctx_from.device_id,
 size, cu_stream));
 }
-} else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) {
+} else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
 CUDA_CALL(hipSetDevice(ctx_from.device_id));
 GPUCopy(from, to, size, hipMemcpyDeviceToHost, cu_stream);
-} else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLGPU) {
+} else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) {
 CUDA_CALL(hipSetDevice(ctx_to.device_id));
 GPUCopy(from, to, size, hipMemcpyHostToDevice, cu_stream);
 } else {
@@ -314,7 +314,7 @@ class CUDADeviceAPI final : public DeviceAPI {
 typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
 CUDAThreadEntry::CUDAThreadEntry()
-: pool(kDLGPU, CUDADeviceAPI::Global()) {
+: pool(kDLROCM, CUDADeviceAPI::Global()) {
 }
 CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
......
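The CopyDataFromTo hunk above dispatches on the (from, to) device types to choose between device-to-device, device-to-host, and host-to-device copies. A compact sketch of that decision table, as a hypothetical helper using the same DLPack values (the real method additionally handles peer-to-peer copies across device ids):

```cpp
#include <hip/hip_runtime.h>

enum DeviceType { kDLCPU = 1, kDLROCM = 10 };  // DLPack device-type values

// Map a (source, destination) context pair to the hipMemcpyKind used by the
// corresponding GPUCopy call in the hunk above.
hipMemcpyKind CopyKind(DeviceType from, DeviceType to) {
  if (from == kDLROCM && to == kDLROCM) return hipMemcpyDeviceToDevice;
  if (from == kDLROCM && to == kDLCPU)  return hipMemcpyDeviceToHost;
  if (from == kDLCPU  && to == kDLROCM) return hipMemcpyHostToDevice;
  return hipMemcpyHostToHost;  // CPU-to-CPU fallback (branch not shown above)
}
```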
@@ -618,7 +618,7 @@ void NCCLCommunicator::AllToAllV(
 int dev_id;
 CUDA_CALL(hipGetDevice(&dev_id));
-DGLContext ctx{kDLGPU, dev_id};
+DGLContext ctx{kDLROCM, dev_id};
 auto device = runtime::DeviceAPI::Get(ctx);
 auto dtype = DLDataTypeTraits<DType>::dtype;
@@ -680,7 +680,7 @@ void NCCLCommunicator::AllToAll(
 #else
 int dev_id;
 CUDA_CALL(hipGetDevice(&dev_id));
-DGLContext ctx{kDLGPU, dev_id};
+DGLContext ctx{kDLROCM, dev_id};
 auto device = runtime::DeviceAPI::Get(ctx);
 auto dtype = DLDataTypeTraits<IdType>::dtype;
......
@@ -262,7 +262,7 @@ void NDArray::PinContainer(NDArray::Container* ptr) {
 auto* tensor = &(ptr->dl_tensor);
 CHECK_EQ(tensor->ctx.device_type, kDLCPU)
 << "Only NDArray on CPU can be pinned";
-DeviceAPI::Get(kDLGPU)->PinData(tensor->data, GetDataSize(*tensor));
+DeviceAPI::Get(kDLROCM)->PinData(tensor->data, GetDataSize(*tensor));
 ptr->pinned_by_dgl_ = true;
 }
@@ -275,14 +275,14 @@ void NDArray::UnpinContainer(NDArray::Container* ptr) {
 // 1. not pinned, do nothing
 if (!container_is_pinned) return;
 // 2. pinned by DGL, unpin it
-DeviceAPI::Get(kDLGPU)->UnpinData(ptr->dl_tensor.data);
+DeviceAPI::Get(kDLROCM)->UnpinData(ptr->dl_tensor.data);
 ptr->pinned_by_dgl_ = false;
 }
 void NDArray::RecordStream(DGLArray* tensor, DGLStreamHandle stream) {
 TensorDispatcher* td = TensorDispatcher::Global();
 CHECK(td->IsAvailable()) << "RecordStream only works when TensorAdaptor is available.";
-CHECK_EQ(tensor->ctx.device_type, kDLGPU)
+CHECK_EQ(tensor->ctx.device_type, kDLROCM)
 << "RecordStream only works with GPU tensors.";
 td->RecordStream(tensor->data, stream, tensor->ctx.device_id);
@@ -353,7 +353,7 @@ bool NDArray::IsContainerPinned(NDArray::Container* ptr) {
 if (tensor->ctx.device_type != kDLCPU)
 return false;
 // ... and CUDA device API is enabled, and the tensor is indeed in pinned memory.
-auto device = DeviceAPI::Get(kDLGPU, true);
+auto device = DeviceAPI::Get(kDLROCM, true);
 return device && device->IsPinned(tensor->data);
 }
......