Commit aaaecbc9 authored by lisj
Browse files

处理kDLGPU为kDLROCM (Replace kDLGPU with kDLROCM)

parent c454d419
...@@ -219,7 +219,7 @@ std::pair<IdArray, IdArray> RandomWalkUniform( ...@@ -219,7 +219,7 @@ std::pair<IdArray, IdArray> RandomWalkUniform(
dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE); dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000); const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, FloatType, "random walk GPU kernel", { ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, FloatType, "random walk GPU kernel", {
CHECK(restart_prob->ctx.device_type == kDLGPU) << "restart prob should be in GPU."; CHECK(restart_prob->ctx.device_type == kDLROCM) << "restart prob should be in GPU.";
CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1."; CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>(); const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
const int64_t restart_prob_size = restart_prob->shape[0]; const int64_t restart_prob_size = restart_prob->shape[0];
...@@ -350,7 +350,7 @@ std::pair<IdArray, IdArray> RandomWalkBiased( ...@@ -350,7 +350,7 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
dim3 block(256); dim3 block(256);
dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE); dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000); const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
CHECK(restart_prob->ctx.device_type == kDLGPU) << "restart prob should be in GPU."; CHECK(restart_prob->ctx.device_type == kDLROCM) << "restart prob should be in GPU.";
CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1."; CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>(); const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
const int64_t restart_prob_size = restart_prob->shape[0]; const int64_t restart_prob_size = restart_prob->shape[0];
...@@ -480,7 +480,7 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors( ...@@ -480,7 +480,7 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
const IdArray dst, const IdArray dst,
const int64_t num_samples_per_node, const int64_t num_samples_per_node,
const int64_t k) { const int64_t k) {
CHECK(src->ctx.device_type == kDLGPU) << CHECK(src->ctx.device_type == kDLROCM) <<
"IdArray needs be on GPU!"; "IdArray needs be on GPU!";
const IdxType* src_data = src.Ptr<IdxType>(); const IdxType* src_data = src.Ptr<IdxType>();
const IdxType* dst_data = dst.Ptr<IdxType>(); const IdxType* dst_data = dst.Ptr<IdxType>();
...@@ -496,27 +496,27 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors( ...@@ -496,27 +496,27 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
} }
template template
std::pair<IdArray, IdArray> RandomWalk<kDLGPU, int32_t>( std::pair<IdArray, IdArray> RandomWalk<kDLROCM, int32_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray seeds, const IdArray seeds,
const TypeArray metapath, const TypeArray metapath,
const std::vector<FloatArray> &prob); const std::vector<FloatArray> &prob);
template template
std::pair<IdArray, IdArray> RandomWalk<kDLGPU, int64_t>( std::pair<IdArray, IdArray> RandomWalk<kDLROCM, int64_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray seeds, const IdArray seeds,
const TypeArray metapath, const TypeArray metapath,
const std::vector<FloatArray> &prob); const std::vector<FloatArray> &prob);
template template
std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLGPU, int32_t>( std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLROCM, int32_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray seeds, const IdArray seeds,
const TypeArray metapath, const TypeArray metapath,
const std::vector<FloatArray> &prob, const std::vector<FloatArray> &prob,
double restart_prob); double restart_prob);
template template
std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLGPU, int64_t>( std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLROCM, int64_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray seeds, const IdArray seeds,
const TypeArray metapath, const TypeArray metapath,
...@@ -524,14 +524,14 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLGPU, int64_t>( ...@@ -524,14 +524,14 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLGPU, int64_t>(
double restart_prob); double restart_prob);
template template
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLGPU, int32_t>( std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLROCM, int32_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray seeds, const IdArray seeds,
const TypeArray metapath, const TypeArray metapath,
const std::vector<FloatArray> &prob, const std::vector<FloatArray> &prob,
FloatArray restart_prob); FloatArray restart_prob);
template template
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLGPU, int64_t>( std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLROCM, int64_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray seeds, const IdArray seeds,
const TypeArray metapath, const TypeArray metapath,
...@@ -539,13 +539,13 @@ std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLGPU, int64_t>( ...@@ -539,13 +539,13 @@ std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLGPU, int64_t>(
FloatArray restart_prob); FloatArray restart_prob);
template template
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLGPU, int32_t>( std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLROCM, int32_t>(
const IdArray src, const IdArray src,
const IdArray dst, const IdArray dst,
const int64_t num_samples_per_node, const int64_t num_samples_per_node,
const int64_t k); const int64_t k);
template template
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLGPU, int64_t>( std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLROCM, int64_t>(
const IdArray src, const IdArray src,
const IdArray dst, const IdArray dst,
const int64_t num_samples_per_node, const int64_t num_samples_per_node,
......
...@@ -36,7 +36,7 @@ void CheckRandomWalkInputs( ...@@ -36,7 +36,7 @@ void CheckRandomWalkInputs(
// CHECK_SAME_CONTEXT(seeds, metapath); // CHECK_SAME_CONTEXT(seeds, metapath);
if (hg->IsPinned()) { if (hg->IsPinned()) {
CHECK_EQ(seeds->ctx.device_type, kDLGPU) << "Expected seeds (" << seeds->ctx << ")" \ CHECK_EQ(seeds->ctx.device_type, kDLROCM) << "Expected seeds (" << seeds->ctx << ")" \
<< " to be on the GPU when the graph is pinned."; << " to be on the GPU when the graph is pinned.";
} else if (hg->Context() != seeds->ctx) { } else if (hg->Context() != seeds->ctx) {
LOG(FATAL) << "Expected seeds (" << seeds->ctx << ")" << " to have the same " \ LOG(FATAL) << "Expected seeds (" << seeds->ctx << ")" << " to have the same " \
......
...@@ -70,7 +70,7 @@ void BuildNodeMaps( ...@@ -70,7 +70,7 @@ void BuildNodeMaps(
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
const IdArray& nodes = input_nodes[ntype]; const IdArray& nodes = input_nodes[ntype];
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
CHECK_EQ(nodes->ctx.device_type, kDLGPU); CHECK_EQ(nodes->ctx.device_type, kDLROCM);
node_maps->LhsHashTable(ntype).FillWithDuplicates( node_maps->LhsHashTable(ntype).FillWithDuplicates(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(),
nodes->shape[0], nodes->shape[0],
...@@ -92,7 +92,7 @@ CompactGraphsGPU( ...@@ -92,7 +92,7 @@ CompactGraphsGPU(
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
hipStream_t stream = runtime::getCurrentCUDAStream(); hipStream_t stream = runtime::getCurrentCUDAStream();
CHECK_EQ(ctx.device_type, kDLGPU); CHECK_EQ(ctx.device_type, kDLROCM);
// Step 1: Collect the nodes that has connections for each type. // Step 1: Collect the nodes that has connections for each type.
const uint64_t num_ntypes = graphs[0]->NumVertexTypes(); const uint64_t num_ntypes = graphs[0]->NumVertexTypes();
...@@ -255,7 +255,7 @@ CompactGraphsGPU( ...@@ -255,7 +255,7 @@ CompactGraphsGPU(
template<> template<>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs<kDLGPU, int32_t>( CompactGraphs<kDLROCM, int32_t>(
const std::vector<HeteroGraphPtr> &graphs, const std::vector<HeteroGraphPtr> &graphs,
const std::vector<IdArray> &always_preserve) { const std::vector<IdArray> &always_preserve) {
return CompactGraphsGPU<int32_t>(graphs, always_preserve); return CompactGraphsGPU<int32_t>(graphs, always_preserve);
...@@ -263,7 +263,7 @@ CompactGraphs<kDLGPU, int32_t>( ...@@ -263,7 +263,7 @@ CompactGraphs<kDLGPU, int32_t>(
template<> template<>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs<kDLGPU, int64_t>( CompactGraphs<kDLROCM, int64_t>(
const std::vector<HeteroGraphPtr> &graphs, const std::vector<HeteroGraphPtr> &graphs,
const std::vector<IdArray> &always_preserve) { const std::vector<IdArray> &always_preserve) {
return CompactGraphsGPU<int64_t>(graphs, always_preserve); return CompactGraphsGPU<int64_t>(graphs, always_preserve);
......
...@@ -82,7 +82,7 @@ class DeviceNodeMapMaker { ...@@ -82,7 +82,7 @@ class DeviceNodeMapMaker {
for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) {
const IdArray& nodes = lhs_nodes[ntype]; const IdArray& nodes = lhs_nodes[ntype];
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
CHECK_EQ(nodes->ctx.device_type, kDLGPU); CHECK_EQ(nodes->ctx.device_type, kDLROCM);
node_maps->LhsHashTable(ntype).FillWithDuplicates( node_maps->LhsHashTable(ntype).FillWithDuplicates(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(),
nodes->shape[0], nodes->shape[0],
...@@ -127,7 +127,7 @@ class DeviceNodeMapMaker { ...@@ -127,7 +127,7 @@ class DeviceNodeMapMaker {
for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) {
const IdArray& nodes = lhs_nodes[ntype]; const IdArray& nodes = lhs_nodes[ntype];
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
CHECK_EQ(nodes->ctx.device_type, kDLGPU); CHECK_EQ(nodes->ctx.device_type, kDLROCM);
node_maps->LhsHashTable(ntype).FillWithUnique( node_maps->LhsHashTable(ntype).FillWithUnique(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(),
nodes->shape[0], nodes->shape[0],
...@@ -154,7 +154,7 @@ class DeviceNodeMapMaker { ...@@ -154,7 +154,7 @@ class DeviceNodeMapMaker {
// Since partial specialization is not allowed for functions, use this as an // Since partial specialization is not allowed for functions, use this as an
// intermediate for ToBlock where XPU = kDLGPU. // intermediate for ToBlock where XPU = kDLROCM.
template<typename IdType> template<typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>>
ToBlockGPU( ToBlockGPU(
...@@ -170,7 +170,7 @@ ToBlockGPU( ...@@ -170,7 +170,7 @@ ToBlockGPU(
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
hipStream_t stream = runtime::getCurrentCUDAStream(); hipStream_t stream = runtime::getCurrentCUDAStream();
CHECK_EQ(ctx.device_type, kDLGPU); CHECK_EQ(ctx.device_type, kDLROCM);
for (const auto& nodes : rhs_nodes) { for (const auto& nodes : rhs_nodes) {
CHECK_EQ(ctx.device_type, nodes->ctx.device_type); CHECK_EQ(ctx.device_type, nodes->ctx.device_type);
} }
...@@ -383,7 +383,7 @@ ToBlockGPU( ...@@ -383,7 +383,7 @@ ToBlockGPU(
// functions are the same. // functions are the same.
// Using template<> fails to export the symbols. // Using template<> fails to export the symbols.
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>>
// ToBlock<kDLGPU, int32_t> // ToBlock<kDLROCM, int32_t>
ToBlockGPU32( ToBlockGPU32(
HeteroGraphPtr graph, HeteroGraphPtr graph,
const std::vector<IdArray> &rhs_nodes, const std::vector<IdArray> &rhs_nodes,
...@@ -393,7 +393,7 @@ ToBlockGPU32( ...@@ -393,7 +393,7 @@ ToBlockGPU32(
} }
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>>
// ToBlock<kDLGPU, int64_t> // ToBlock<kDLROCM, int64_t>
ToBlockGPU64( ToBlockGPU64(
HeteroGraphPtr graph, HeteroGraphPtr graph,
const std::vector<IdArray> &rhs_nodes, const std::vector<IdArray> &rhs_nodes,
......
...@@ -923,36 +923,36 @@ void NNDescent(const NDArray& points, const IdArray& offsets, ...@@ -923,36 +923,36 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
device->FreeWorkspace(ctx, sum_temp_storage); device->FreeWorkspace(ctx, sum_temp_storage);
} }
template void KNN<kDLGPU, float, int32_t>( template void KNN<kDLROCM, float, int32_t>(
const NDArray& data_points, const IdArray& data_offsets, const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets, const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm); const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLGPU, float, int64_t>( template void KNN<kDLROCM, float, int64_t>(
const NDArray& data_points, const IdArray& data_offsets, const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets, const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm); const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLGPU, double, int32_t>( template void KNN<kDLROCM, double, int32_t>(
const NDArray& data_points, const IdArray& data_offsets, const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets, const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm); const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLGPU, double, int64_t>( template void KNN<kDLROCM, double, int64_t>(
const NDArray& data_points, const IdArray& data_offsets, const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets, const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm); const int k, IdArray result, const std::string& algorithm);
template void NNDescent<kDLGPU, float, int32_t>( template void NNDescent<kDLROCM, float, int32_t>(
const NDArray& points, const IdArray& offsets, const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters, IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta); const int num_candidates, const double delta);
template void NNDescent<kDLGPU, float, int64_t>( template void NNDescent<kDLROCM, float, int64_t>(
const NDArray& points, const IdArray& offsets, const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters, IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta); const int num_candidates, const double delta);
template void NNDescent<kDLGPU, double, int32_t>( template void NNDescent<kDLROCM, double, int32_t>(
const NDArray& points, const IdArray& offsets, const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters, IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta); const int num_candidates, const double delta);
template void NNDescent<kDLGPU, double, int64_t>( template void NNDescent<kDLROCM, double, int64_t>(
const NDArray& points, const IdArray& offsets, const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters, IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta); const int num_candidates, const double delta);
......
...@@ -172,7 +172,7 @@ ToBlockGPU64(HeteroGraphPtr, const std::vector<IdArray>&, bool, std::vector<IdAr ...@@ -172,7 +172,7 @@ ToBlockGPU64(HeteroGraphPtr, const std::vector<IdArray>&, bool, std::vector<IdAr
template<> template<>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>>
ToBlock<kDLGPU, int32_t>(HeteroGraphPtr graph, ToBlock<kDLROCM, int32_t>(HeteroGraphPtr graph,
const std::vector<IdArray> &rhs_nodes, const std::vector<IdArray> &rhs_nodes,
bool include_rhs_in_lhs, bool include_rhs_in_lhs,
std::vector<IdArray>* const lhs_nodes) { std::vector<IdArray>* const lhs_nodes) {
...@@ -181,7 +181,7 @@ ToBlock<kDLGPU, int32_t>(HeteroGraphPtr graph, ...@@ -181,7 +181,7 @@ ToBlock<kDLGPU, int32_t>(HeteroGraphPtr graph,
template<> template<>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>>
ToBlock<kDLGPU, int64_t>(HeteroGraphPtr graph, ToBlock<kDLROCM, int64_t>(HeteroGraphPtr graph,
const std::vector<IdArray> &rhs_nodes, const std::vector<IdArray> &rhs_nodes,
bool include_rhs_in_lhs, bool include_rhs_in_lhs,
std::vector<IdArray>* const lhs_nodes) { std::vector<IdArray>* const lhs_nodes) {
......
...@@ -214,7 +214,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -214,7 +214,7 @@ class UnitGraph : public BaseHeteroGraph {
* \note The graph will be pinned inplace. Behavior depends on the current context, * \note The graph will be pinned inplace. Behavior depends on the current context,
* kDLCPU: will be pinned; * kDLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDLROCM: invalid, will throw an error.
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
void PinMemory_() override; void PinMemory_() override;
......
...@@ -377,12 +377,12 @@ GeneratePermutationFromRemainder( ...@@ -377,12 +377,12 @@ GeneratePermutationFromRemainder(
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder<kDLGPU, int32_t>( GeneratePermutationFromRemainder<kDLROCM, int32_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray in_idx); IdArray in_idx);
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder<kDLGPU, int64_t>( GeneratePermutationFromRemainder<kDLROCM, int64_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray in_idx); IdArray in_idx);
...@@ -421,11 +421,11 @@ IdArray MapToLocalFromRemainder( ...@@ -421,11 +421,11 @@ IdArray MapToLocalFromRemainder(
} }
template IdArray template IdArray
MapToLocalFromRemainder<kDLGPU, int32_t>( MapToLocalFromRemainder<kDLROCM, int32_t>(
int num_parts, int num_parts,
IdArray in_idx); IdArray in_idx);
template IdArray template IdArray
MapToLocalFromRemainder<kDLGPU, int64_t>( MapToLocalFromRemainder<kDLROCM, int64_t>(
int num_parts, int num_parts,
IdArray in_idx); IdArray in_idx);
...@@ -469,12 +469,12 @@ IdArray MapToGlobalFromRemainder( ...@@ -469,12 +469,12 @@ IdArray MapToGlobalFromRemainder(
} }
template IdArray template IdArray
MapToGlobalFromRemainder<kDLGPU, int32_t>( MapToGlobalFromRemainder<kDLROCM, int32_t>(
int num_parts, int num_parts,
IdArray in_idx, IdArray in_idx,
int part_id); int part_id);
template IdArray template IdArray
MapToGlobalFromRemainder<kDLGPU, int64_t>( MapToGlobalFromRemainder<kDLROCM, int64_t>(
int num_parts, int num_parts,
IdArray in_idx, IdArray in_idx,
int part_id); int part_id);
...@@ -599,25 +599,25 @@ GeneratePermutationFromRange( ...@@ -599,25 +599,25 @@ GeneratePermutationFromRange(
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int32_t, int32_t>( GeneratePermutationFromRange<kDLROCM, int32_t, int32_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int64_t, int32_t>( GeneratePermutationFromRange<kDLROCM, int64_t, int32_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int32_t, int64_t>( GeneratePermutationFromRange<kDLROCM, int32_t, int64_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int64_t, int64_t>( GeneratePermutationFromRange<kDLROCM, int64_t, int64_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray range, IdArray range,
...@@ -658,22 +658,22 @@ IdArray MapToLocalFromRange( ...@@ -658,22 +658,22 @@ IdArray MapToLocalFromRange(
} }
template IdArray template IdArray
MapToLocalFromRange<kDLGPU, int32_t, int32_t>( MapToLocalFromRange<kDLROCM, int32_t, int32_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template IdArray template IdArray
MapToLocalFromRange<kDLGPU, int64_t, int32_t>( MapToLocalFromRange<kDLROCM, int64_t, int32_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template IdArray template IdArray
MapToLocalFromRange<kDLGPU, int32_t, int64_t>( MapToLocalFromRange<kDLROCM, int32_t, int64_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template IdArray template IdArray
MapToLocalFromRange<kDLGPU, int64_t, int64_t>( MapToLocalFromRange<kDLROCM, int64_t, int64_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
...@@ -721,25 +721,25 @@ IdArray MapToGlobalFromRange( ...@@ -721,25 +721,25 @@ IdArray MapToGlobalFromRange(
} }
template IdArray template IdArray
MapToGlobalFromRange<kDLGPU, int32_t, int32_t>( MapToGlobalFromRange<kDLROCM, int32_t, int32_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx, IdArray in_idx,
int part_id); int part_id);
template IdArray template IdArray
MapToGlobalFromRange<kDLGPU, int64_t, int32_t>( MapToGlobalFromRange<kDLROCM, int64_t, int32_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx, IdArray in_idx,
int part_id); int part_id);
template IdArray template IdArray
MapToGlobalFromRange<kDLGPU, int32_t, int64_t>( MapToGlobalFromRange<kDLROCM, int32_t, int64_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx, IdArray in_idx,
int part_id); int part_id);
template IdArray template IdArray
MapToGlobalFromRange<kDLGPU, int64_t, int64_t>( MapToGlobalFromRange<kDLROCM, int64_t, int64_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx, IdArray in_idx,
......
...@@ -46,9 +46,9 @@ class RemainderPartition : public NDArrayPartition { ...@@ -46,9 +46,9 @@ class RemainderPartition : public NDArrayPartition {
IdArray in_idx) const override { IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDLROCM) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::GeneratePermutationFromRemainder<kDLGPU, IdType>( return impl::GeneratePermutationFromRemainder<kDLROCM, IdType>(
ArraySize(), NumParts(), in_idx); ArraySize(), NumParts(), in_idx);
}); });
} }
...@@ -64,9 +64,9 @@ class RemainderPartition : public NDArrayPartition { ...@@ -64,9 +64,9 @@ class RemainderPartition : public NDArrayPartition {
IdArray in_idx) const override { IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDLROCM) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::MapToLocalFromRemainder<kDLGPU, IdType>( return impl::MapToLocalFromRemainder<kDLROCM, IdType>(
NumParts(), in_idx); NumParts(), in_idx);
}); });
} }
...@@ -83,9 +83,9 @@ class RemainderPartition : public NDArrayPartition { ...@@ -83,9 +83,9 @@ class RemainderPartition : public NDArrayPartition {
const int part_id) const override { const int part_id) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDLROCM) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::MapToGlobalFromRemainder<kDLGPU, IdType>( return impl::MapToGlobalFromRemainder<kDLROCM, IdType>(
NumParts(), in_idx, part_id); NumParts(), in_idx, part_id);
}); });
} }
...@@ -118,7 +118,7 @@ class RangePartition : public NDArrayPartition { ...@@ -118,7 +118,7 @@ class RangePartition : public NDArrayPartition {
// have only one CPU context, and can safely copy the array to that. // have only one CPU context, and can safely copy the array to that.
range_cpu_(range.CopyTo(DGLContext{kDLCPU, 0})) { range_cpu_(range.CopyTo(DGLContext{kDLCPU, 0})) {
auto ctx = range->ctx; auto ctx = range->ctx;
if (ctx.device_type != kDLGPU) { if (ctx.device_type != kDLROCM) {
LOG(FATAL) << "The range for an NDArrayPartition is only supported " LOG(FATAL) << "The range for an NDArrayPartition is only supported "
" on GPUs. Transfer the range to the target device before " " on GPUs. Transfer the range to the target device before "
"creating the partition."; "creating the partition.";
...@@ -130,7 +130,7 @@ class RangePartition : public NDArrayPartition { ...@@ -130,7 +130,7 @@ class RangePartition : public NDArrayPartition {
IdArray in_idx) const override { IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDLROCM) {
if (ctx.device_type != range_->ctx.device_type || if (ctx.device_type != range_->ctx.device_type ||
ctx.device_id != range_->ctx.device_id) { ctx.device_id != range_->ctx.device_id) {
LOG(FATAL) << "The range for the NDArrayPartition and the input " LOG(FATAL) << "The range for the NDArrayPartition and the input "
...@@ -138,7 +138,7 @@ class RangePartition : public NDArrayPartition { ...@@ -138,7 +138,7 @@ class RangePartition : public NDArrayPartition {
} }
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, { ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::GeneratePermutationFromRange<kDLGPU, IdType, RangeType>( return impl::GeneratePermutationFromRange<kDLROCM, IdType, RangeType>(
ArraySize(), NumParts(), range_, in_idx); ArraySize(), NumParts(), range_, in_idx);
}); });
}); });
...@@ -155,10 +155,10 @@ class RangePartition : public NDArrayPartition { ...@@ -155,10 +155,10 @@ class RangePartition : public NDArrayPartition {
IdArray in_idx) const override { IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDLROCM) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, { ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::MapToLocalFromRange<kDLGPU, IdType, RangeType>( return impl::MapToLocalFromRange<kDLROCM, IdType, RangeType>(
NumParts(), range_, in_idx); NumParts(), range_, in_idx);
}); });
}); });
...@@ -176,10 +176,10 @@ class RangePartition : public NDArrayPartition { ...@@ -176,10 +176,10 @@ class RangePartition : public NDArrayPartition {
const int part_id) const override { const int part_id) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDLROCM) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, { ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::MapToGlobalFromRange<kDLGPU, IdType, RangeType>( return impl::MapToGlobalFromRange<kDLROCM, IdType, RangeType>(
NumParts(), range_, in_idx, part_id); NumParts(), range_, in_idx, part_id);
}); });
}); });
......
...@@ -29,7 +29,7 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed") ...@@ -29,7 +29,7 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
} }
}); });
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
if (DeviceAPI::Get(kDLGPU)->IsAvailable()) { if (DeviceAPI::Get(kDLROCM)->IsAvailable()) {
auto* thr_entry = CUDAThreadEntry::ThreadLocal(); auto* thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) { if (!thr_entry->curand_gen) {
CURAND_CALL(hiprandCreateGenerator(&thr_entry->curand_gen, HIPRAND_RNG_PSEUDO_DEFAULT)); CURAND_CALL(hiprandCreateGenerator(&thr_entry->curand_gen, HIPRAND_RNG_PSEUDO_DEFAULT));
......
...@@ -27,7 +27,7 @@ namespace runtime { ...@@ -27,7 +27,7 @@ namespace runtime {
inline std::string DeviceName(int type) { inline std::string DeviceName(int type) {
switch (type) { switch (type) {
case kDLCPU: return "cpu"; case kDLCPU: return "cpu";
case kDLGPU: return "gpu"; case kDLROCM: return "gpu";
case kDLOpenCL: return "opencl"; case kDLOpenCL: return "opencl";
case kDLSDAccel: return "sdaccel"; case kDLSDAccel: return "sdaccel";
case kDLAOCL: return "aocl"; case kDLAOCL: return "aocl";
......
...@@ -141,7 +141,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -141,7 +141,7 @@ class CUDADeviceAPI final : public DeviceAPI {
hipStream_t cu_stream = static_cast<hipStream_t>(stream); hipStream_t cu_stream = static_cast<hipStream_t>(stream);
from = static_cast<const char*>(from) + from_offset; from = static_cast<const char*>(from) + from_offset;
to = static_cast<char*>(to) + to_offset; to = static_cast<char*>(to) + to_offset;
if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) { if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM) {
CUDA_CALL(hipSetDevice(ctx_from.device_id)); CUDA_CALL(hipSetDevice(ctx_from.device_id));
if (ctx_from.device_id == ctx_to.device_id) { if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, hipMemcpyDeviceToDevice, cu_stream); GPUCopy(from, to, size, hipMemcpyDeviceToDevice, cu_stream);
...@@ -150,10 +150,10 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -150,10 +150,10 @@ class CUDADeviceAPI final : public DeviceAPI {
from, ctx_from.device_id, from, ctx_from.device_id,
size, cu_stream)); size, cu_stream));
} }
} else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) { } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
CUDA_CALL(hipSetDevice(ctx_from.device_id)); CUDA_CALL(hipSetDevice(ctx_from.device_id));
GPUCopy(from, to, size, hipMemcpyDeviceToHost, cu_stream); GPUCopy(from, to, size, hipMemcpyDeviceToHost, cu_stream);
} else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLGPU) { } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) {
CUDA_CALL(hipSetDevice(ctx_to.device_id)); CUDA_CALL(hipSetDevice(ctx_to.device_id));
GPUCopy(from, to, size, hipMemcpyHostToDevice, cu_stream); GPUCopy(from, to, size, hipMemcpyHostToDevice, cu_stream);
} else { } else {
...@@ -314,7 +314,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -314,7 +314,7 @@ class CUDADeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore; typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
CUDAThreadEntry::CUDAThreadEntry() CUDAThreadEntry::CUDAThreadEntry()
: pool(kDLGPU, CUDADeviceAPI::Global()) { : pool(kDLROCM, CUDADeviceAPI::Global()) {
} }
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
......
...@@ -618,7 +618,7 @@ void NCCLCommunicator::AllToAllV( ...@@ -618,7 +618,7 @@ void NCCLCommunicator::AllToAllV(
int dev_id; int dev_id;
CUDA_CALL(hipGetDevice(&dev_id)); CUDA_CALL(hipGetDevice(&dev_id));
DGLContext ctx{kDLGPU, dev_id}; DGLContext ctx{kDLROCM, dev_id};
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
auto dtype = DLDataTypeTraits<DType>::dtype; auto dtype = DLDataTypeTraits<DType>::dtype;
...@@ -680,7 +680,7 @@ void NCCLCommunicator::AllToAll( ...@@ -680,7 +680,7 @@ void NCCLCommunicator::AllToAll(
#else #else
int dev_id; int dev_id;
CUDA_CALL(hipGetDevice(&dev_id)); CUDA_CALL(hipGetDevice(&dev_id));
DGLContext ctx{kDLGPU, dev_id}; DGLContext ctx{kDLROCM, dev_id};
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
auto dtype = DLDataTypeTraits<IdType>::dtype; auto dtype = DLDataTypeTraits<IdType>::dtype;
......
...@@ -262,7 +262,7 @@ void NDArray::PinContainer(NDArray::Container* ptr) { ...@@ -262,7 +262,7 @@ void NDArray::PinContainer(NDArray::Container* ptr) {
auto* tensor = &(ptr->dl_tensor); auto* tensor = &(ptr->dl_tensor);
CHECK_EQ(tensor->ctx.device_type, kDLCPU) CHECK_EQ(tensor->ctx.device_type, kDLCPU)
<< "Only NDArray on CPU can be pinned"; << "Only NDArray on CPU can be pinned";
DeviceAPI::Get(kDLGPU)->PinData(tensor->data, GetDataSize(*tensor)); DeviceAPI::Get(kDLROCM)->PinData(tensor->data, GetDataSize(*tensor));
ptr->pinned_by_dgl_ = true; ptr->pinned_by_dgl_ = true;
} }
...@@ -275,14 +275,14 @@ void NDArray::UnpinContainer(NDArray::Container* ptr) { ...@@ -275,14 +275,14 @@ void NDArray::UnpinContainer(NDArray::Container* ptr) {
// 1. not pinned, do nothing // 1. not pinned, do nothing
if (!container_is_pinned) return; if (!container_is_pinned) return;
// 2. pinned by DGL, unpin it // 2. pinned by DGL, unpin it
DeviceAPI::Get(kDLGPU)->UnpinData(ptr->dl_tensor.data); DeviceAPI::Get(kDLROCM)->UnpinData(ptr->dl_tensor.data);
ptr->pinned_by_dgl_ = false; ptr->pinned_by_dgl_ = false;
} }
void NDArray::RecordStream(DGLArray* tensor, DGLStreamHandle stream) { void NDArray::RecordStream(DGLArray* tensor, DGLStreamHandle stream) {
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
CHECK(td->IsAvailable()) << "RecordStream only works when TensorAdaptor is available."; CHECK(td->IsAvailable()) << "RecordStream only works when TensorAdaptor is available.";
CHECK_EQ(tensor->ctx.device_type, kDLGPU) CHECK_EQ(tensor->ctx.device_type, kDLROCM)
<< "RecordStream only works with GPU tensors."; << "RecordStream only works with GPU tensors.";
td->RecordStream(tensor->data, stream, tensor->ctx.device_id); td->RecordStream(tensor->data, stream, tensor->ctx.device_id);
...@@ -353,7 +353,7 @@ bool NDArray::IsContainerPinned(NDArray::Container* ptr) { ...@@ -353,7 +353,7 @@ bool NDArray::IsContainerPinned(NDArray::Container* ptr) {
if (tensor->ctx.device_type != kDLCPU) if (tensor->ctx.device_type != kDLCPU)
return false; return false;
// ... and CUDA device API is enabled, and the tensor is indeed in pinned memory. // ... and CUDA device API is enabled, and the tensor is indeed in pinned memory.
auto device = DeviceAPI::Get(kDLGPU, true); auto device = DeviceAPI::Get(kDLROCM, true);
return device && device->IsPinned(tensor->data); return device && device->IsPinned(tensor->data);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment