Commit aaaecbc9 authored by lisj's avatar lisj
Browse files

处理kDLGPU为kDLROCM

parent c454d419
......@@ -177,8 +177,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
col_sorted);
}
template COOMatrix DisjointUnionCoo<kDLGPU, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDLGPU, int64_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDLROCM, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDLROCM, int64_t>(const std::vector<COOMatrix>& coos);
} // namespace impl
} // namespace aten
......
......@@ -395,74 +395,74 @@ void GatherMMScatter(const NDArray A,
}
template void GatherMM<kDLGPU, int32_t, 16>(
template void GatherMM<kDLROCM, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 16>(
template void GatherMM<kDLROCM, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int32_t, 32>(
template void GatherMM<kDLROCM, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 32>(
template void GatherMM<kDLROCM, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int32_t, 64>(
template void GatherMM<kDLROCM, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 64>(
template void GatherMM<kDLROCM, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMMScatter<kDLGPU, int32_t, 16>(
template void GatherMMScatter<kDLROCM, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 16>(
template void GatherMMScatter<kDLROCM, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int32_t, 32>(
template void GatherMMScatter<kDLROCM, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 32>(
template void GatherMMScatter<kDLROCM, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int32_t, 64>(
template void GatherMMScatter<kDLROCM, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 64>(
template void GatherMMScatter<kDLROCM, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDLGPU, int32_t, 16>(
template void SegmentMM<kDLROCM, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 16>(
template void SegmentMM<kDLROCM, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int32_t, 32>(
template void SegmentMM<kDLROCM, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 32>(
template void SegmentMM<kDLROCM, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int32_t, 64>(
template void SegmentMM<kDLROCM, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 64>(
template void SegmentMM<kDLROCM, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDLGPU, int32_t, 16>(
template void SegmentMMBackwardB<kDLROCM, int32_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 16>(
template void SegmentMMBackwardB<kDLROCM, int64_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int32_t, 32>(
template void SegmentMMBackwardB<kDLROCM, int32_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 32>(
template void SegmentMMBackwardB<kDLROCM, int64_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int32_t, 64>(
template void SegmentMMBackwardB<kDLROCM, int32_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 64>(
template void SegmentMMBackwardB<kDLROCM, int64_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
} // namespace aten
......
......@@ -212,9 +212,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
return result;
}
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLGPU, int32_t>(
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLROCM, int32_t>(
const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLGPU, int64_t>(
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLROCM, int64_t>(
const CSRMatrix&, int64_t, int, bool, bool, double);
}; // namespace impl
......
......@@ -370,9 +370,9 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
picked_col, picked_idx);
}
template COOMatrix CSRRowWiseSamplingUniform<kDLGPU, int32_t>(
template COOMatrix CSRRowWiseSamplingUniform<kDLROCM, int32_t>(
CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDLGPU, int64_t>(
template COOMatrix CSRRowWiseSamplingUniform<kDLROCM, int64_t>(
CSRMatrix, IdArray, int64_t, bool);
} // namespace impl
......
......@@ -652,13 +652,13 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
picked_col, picked_idx);
}
template COOMatrix CSRRowWiseSampling<kDLGPU, int32_t, float>(
template COOMatrix CSRRowWiseSampling<kDLROCM, int32_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int64_t, float>(
template COOMatrix CSRRowWiseSampling<kDLROCM, int64_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int32_t, double>(
template COOMatrix CSRRowWiseSampling<kDLROCM, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int64_t, double>(
template COOMatrix CSRRowWiseSampling<kDLROCM, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
} // namespace impl
......
......@@ -54,52 +54,52 @@ void SDDMMCoo(const std::string& op,
}
template void SDDMMCsr<kDLGPU, int32_t, 16>(
template void SDDMMCsr<kDLROCM, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 16>(
template void SDDMMCsr<kDLROCM, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int32_t, 32>(
template void SDDMMCsr<kDLROCM, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 32>(
template void SDDMMCsr<kDLROCM, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int32_t, 64>(
template void SDDMMCsr<kDLROCM, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 64>(
template void SDDMMCsr<kDLROCM, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, 16>(
template void SDDMMCoo<kDLROCM, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 16>(
template void SDDMMCoo<kDLROCM, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, 32>(
template void SDDMMCoo<kDLROCM, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 32>(
template void SDDMMCoo<kDLROCM, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, 64>(
template void SDDMMCoo<kDLROCM, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 64>(
template void SDDMMCoo<kDLROCM, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
......
......@@ -42,42 +42,42 @@ void SDDMMCooHetero(const std::string& op,
}
template void SDDMMCooHetero<kDLGPU, int32_t, 16>(
template void SDDMMCooHetero<kDLROCM, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 16>(
template void SDDMMCooHetero<kDLROCM, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int32_t, 32>(
template void SDDMMCooHetero<kDLROCM, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 32>(
template void SDDMMCooHetero<kDLROCM, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int32_t, 64>(
template void SDDMMCooHetero<kDLROCM, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 64>(
template void SDDMMCooHetero<kDLROCM, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
......
......@@ -41,42 +41,42 @@ void SDDMMCsrHetero(const std::string& op,
});
}
template void SDDMMCsrHetero<kDLGPU, int32_t, 16>(
template void SDDMMCsrHetero<kDLROCM, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 16>(
template void SDDMMCsrHetero<kDLROCM, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int32_t, 32>(
template void SDDMMCsrHetero<kDLROCM, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 32>(
template void SDDMMCsrHetero<kDLROCM, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int32_t, 64>(
template void SDDMMCsrHetero<kDLROCM, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 64>(
template void SDDMMCsrHetero<kDLROCM, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
......
......@@ -73,113 +73,113 @@ void BackwardSegmentCmp(NDArray feat,
}
template void SegmentReduce<kDLGPU, int32_t, 16>(
template void SegmentReduce<kDLROCM, int32_t, 16>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int64_t, 16>(
template void SegmentReduce<kDLROCM, int64_t, 16>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int32_t, 32>(
template void SegmentReduce<kDLROCM, int32_t, 32>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int64_t, 32>(
template void SegmentReduce<kDLROCM, int64_t, 32>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int32_t, 64>(
template void SegmentReduce<kDLROCM, int32_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int64_t, 64>(
template void SegmentReduce<kDLROCM, int64_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void ScatterAdd<kDLGPU, int32_t, 16>(
template void ScatterAdd<kDLROCM, int32_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLGPU, int64_t, 16>(
template void ScatterAdd<kDLROCM, int64_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLGPU, int32_t, 32>(
template void ScatterAdd<kDLROCM, int32_t, 32>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLGPU, int64_t, 32>(
template void ScatterAdd<kDLROCM, int64_t, 32>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLGPU, int32_t, 64>(
template void ScatterAdd<kDLROCM, int32_t, 64>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLGPU, int64_t, 64>(
template void ScatterAdd<kDLROCM, int64_t, 64>(
NDArray feat,
NDArray idx,
NDArray out);
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 16>(
template void UpdateGradMinMax_hetero<kDLROCM, int32_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 16>(
template void UpdateGradMinMax_hetero<kDLROCM, int64_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 32>(
template void UpdateGradMinMax_hetero<kDLROCM, int32_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 32>(
template void UpdateGradMinMax_hetero<kDLROCM, int64_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 64>(
template void UpdateGradMinMax_hetero<kDLROCM, int32_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 64>(
template void UpdateGradMinMax_hetero<kDLROCM, int64_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDLGPU, int32_t, 16>(
template void BackwardSegmentCmp<kDLROCM, int32_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int64_t, 16>(
template void BackwardSegmentCmp<kDLROCM, int64_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int32_t, 32>(
template void BackwardSegmentCmp<kDLROCM, int32_t, 32>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int64_t, 32>(
template void BackwardSegmentCmp<kDLROCM, int64_t, 32>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int32_t, 64>(
template void BackwardSegmentCmp<kDLROCM, int32_t, 64>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int64_t, 64>(
template void BackwardSegmentCmp<kDLROCM, int64_t, 64>(
NDArray feat,
NDArray arg,
NDArray out);
......
......@@ -89,8 +89,8 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
return *rst.Ptr<IdType>();
}
template int64_t COOGetRowNNZ<kDLGPU, int32_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDLGPU, int64_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDLROCM, int32_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDLROCM, int64_t>(COOMatrix, int64_t);
template <typename IdType>
__global__ void _COOGetAllRowNNZKernel(
......@@ -137,8 +137,8 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
}
}
template NDArray COOGetRowNNZ<kDLGPU, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDLGPU, int64_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDLROCM, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDLROCM, int64_t>(COOMatrix, NDArray);
} // namespace impl
} // namespace aten
......
......@@ -43,8 +43,8 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
return *out.Ptr<IdType>() != -1;
}
template bool CSRIsNonZero<kDLGPU, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDLGPU, int64_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDLROCM, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDLROCM, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
......@@ -70,8 +70,8 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
return rst != -1;
}
template NDArray CSRIsNonZero<kDLGPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDLGPU, int64_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDLROCM, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDLROCM, int64_t>(CSRMatrix, NDArray, NDArray);
///////////////////////////// CSRHasDuplicate /////////////////////////////
......@@ -117,8 +117,8 @@ bool CSRHasDuplicate(CSRMatrix csr) {
return !ret;
}
template bool CSRHasDuplicate<kDLGPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLGPU, int64_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLROCM, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLROCM, int64_t>(CSRMatrix csr);
///////////////////////////// CSRGetRowNNZ /////////////////////////////
......@@ -129,8 +129,8 @@ int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
return next - cur;
}
template int64_t CSRGetRowNNZ<kDLGPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDLGPU, int64_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDLROCM, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDLROCM, int64_t>(CSRMatrix, int64_t);
template <typename IdType>
__global__ void _CSRGetRowNNZKernel(
......@@ -163,8 +163,8 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
return rst;
}
template NDArray CSRGetRowNNZ<kDLGPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDLGPU, int64_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDLROCM, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDLROCM, int64_t>(CSRMatrix, NDArray);
///////////////////////////// CSRGetRowColumnIndices /////////////////////////////
......@@ -175,8 +175,8 @@ NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
return csr.indices.CreateView({len}, csr.indices->dtype, offset);
}
template NDArray CSRGetRowColumnIndices<kDLGPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDLGPU, int64_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDLROCM, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDLROCM, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetRowData /////////////////////////////
......@@ -190,8 +190,8 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
return aten::Range(offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
}
template NDArray CSRGetRowData<kDLGPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLGPU, int64_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLROCM, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLROCM, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRSliceRows /////////////////////////////
......@@ -216,8 +216,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
csr.sorted);
}
template CSRMatrix CSRSliceRows<kDLGPU, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDLGPU, int64_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDLROCM, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDLROCM, int64_t>(CSRMatrix, int64_t, int64_t);
/*!
* \brief Copy data segment to output buffers
......@@ -273,8 +273,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
csr.sorted);
}
template CSRMatrix CSRSliceRows<kDLGPU, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLGPU, int64_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLROCM, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLROCM, int64_t>(CSRMatrix , NDArray);
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
......@@ -393,9 +393,9 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
return {ret_row, ret_col, ret_data};
}
template std::vector<NDArray> CSRGetDataAndIndices<kDLGPU, int32_t>(
template std::vector<NDArray> CSRGetDataAndIndices<kDLROCM, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
template std::vector<NDArray> CSRGetDataAndIndices<kDLGPU, int64_t>(
template std::vector<NDArray> CSRGetDataAndIndices<kDLROCM, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
///////////////////////////// CSRSliceMatrix /////////////////////////////
......@@ -502,9 +502,9 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
ret_col, ret_data);
}
template CSRMatrix CSRSliceMatrix<kDLGPU, int32_t>(
template CSRMatrix CSRSliceMatrix<kDLROCM, int32_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template CSRMatrix CSRSliceMatrix<kDLGPU, int64_t>(
template CSRMatrix CSRSliceMatrix<kDLROCM, int64_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
} // namespace impl
......
......@@ -147,53 +147,53 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
}
}
template void SpMMCsr<kDLGPU, int32_t, 16>(
template void SpMMCsr<kDLROCM, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int64_t, 16>(
template void SpMMCsr<kDLROCM, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int32_t, 32>(
template void SpMMCsr<kDLROCM, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int64_t, 32>(
template void SpMMCsr<kDLROCM, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int32_t, 64>(
template void SpMMCsr<kDLROCM, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int64_t, 64>(
template void SpMMCsr<kDLROCM, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int32_t, 16>(
template void SpMMCoo<kDLROCM, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int64_t, 16>(
template void SpMMCoo<kDLROCM, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int32_t, 32>(
template void SpMMCoo<kDLROCM, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int64_t, 32>(
template void SpMMCoo<kDLROCM, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int32_t, 64>(
template void SpMMCoo<kDLROCM, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int64_t, 64>(
template void SpMMCoo<kDLROCM, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
......
......@@ -201,37 +201,37 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
});
}
template void SpMMCsrHetero<kDLGPU, int32_t, 16>(
template void SpMMCsrHetero<kDLROCM, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 16>(
template void SpMMCsrHetero<kDLROCM, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 32>(
template void SpMMCsrHetero<kDLROCM, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 32>(
template void SpMMCsrHetero<kDLROCM, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 64>(
template void SpMMCsrHetero<kDLROCM, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 64>(
template void SpMMCsrHetero<kDLROCM, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
......
......@@ -25,7 +25,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
std::vector<int64_t> shape{len};
CHECK(array.IsPinned());
CHECK_EQ(index->ctx.device_type, kDLGPU);
CHECK_EQ(index->ctx.device_type, kDLROCM);
for (int d = 1; d < array->ndim; ++d) {
num_feat *= array->shape[d];
......@@ -85,8 +85,8 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
std::vector<int64_t> shape{len};
CHECK(dest.IsPinned());
CHECK_EQ(index->ctx.device_type, kDLGPU);
CHECK_EQ(source->ctx.device_type, kDLGPU);
CHECK_EQ(index->ctx.device_type, kDLROCM);
CHECK_EQ(source->ctx.device_type, kDLROCM);
for (int d = 1; d < source->ndim; ++d) {
num_feat *= source->shape[d];
......
......@@ -23,10 +23,10 @@ DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterCreateFromSet")
IdArray array = args[0];
auto ctx = array->ctx;
// TODO(nv-dlasalle): Implement CPU version.
if (ctx.device_type == kDLGPU) {
if (ctx.device_type == kDLROCM) {
#ifdef DGL_USE_CUDA
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
*rv = CreateSetFilter<kDLGPU, IdType>(array);
*rv = CreateSetFilter<kDLROCM, IdType>(array);
});
#else
LOG(FATAL) << "GPU support not compiled.";
......
......@@ -16,7 +16,7 @@ namespace aten {
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
#ifdef DGL_USE_CUDA
CHECK(array.IsPinned()) << "Input array must be in pinned memory.";
CHECK_EQ(index->ctx.device_type, kDLGPU) << "Index must be on the GPU.";
CHECK_EQ(index->ctx.device_type, kDLROCM) << "Index must be on the GPU.";
CHECK_GE(array->ndim, 1) << "Input array must have at least 1 dimension.";
CHECK_EQ(index->ndim, 1) << "Index must be a 1D array.";
......@@ -34,8 +34,8 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
#ifdef DGL_USE_CUDA
CHECK(dest.IsPinned()) << "Destination array must be in pinned memory.";
CHECK_EQ(index->ctx.device_type, kDLGPU) << "Index must be on the GPU.";
CHECK_EQ(source->ctx.device_type, kDLGPU) << "Source array must be on the GPU.";
CHECK_EQ(index->ctx.device_type, kDLROCM) << "Index must be on the GPU.";
CHECK_EQ(source->ctx.device_type, kDLROCM) << "Source array must be on the GPU.";
CHECK_EQ(dest->dtype, source->dtype) << "Destination array and source "
"array must have the same dtype.";
CHECK_GE(dest->ndim, 1) << "Destination array must have at least 1 dimension.";
......
......@@ -183,13 +183,13 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
}
device->FreeWorkspace(ctx, prop);
}
template void WeightedNeighborMatching<kDLGPU, float, int32_t>(
template void WeightedNeighborMatching<kDLROCM, float, int32_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLGPU, float, int64_t>(
template void WeightedNeighborMatching<kDLROCM, float, int64_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLGPU, double, int32_t>(
template void WeightedNeighborMatching<kDLROCM, double, int32_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLGPU, double, int64_t>(
template void WeightedNeighborMatching<kDLROCM, double, int64_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
/*! \brief Unweighted neighbor matching procedure (GPU version).
......@@ -222,8 +222,8 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result);
}
template void NeighborMatching<kDLGPU, int32_t>(const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDLGPU, int64_t>(const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDLROCM, int32_t>(const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDLROCM, int64_t>(const aten::CSRMatrix &csr, IdArray result);
} // namespace impl
} // namespace geometry
......
......@@ -116,16 +116,16 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
point_in_batch, dim, start_idx_data, dist_data, ret_data);
}
template void FarthestPointSampler<kDLGPU, float, int32_t>(
template void FarthestPointSampler<kDLROCM, float, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLGPU, float, int64_t>(
template void FarthestPointSampler<kDLROCM, float, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLGPU, double, int32_t>(
template void FarthestPointSampler<kDLROCM, double, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLGPU, double, int64_t>(
template void FarthestPointSampler<kDLROCM, double, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
......
......@@ -237,7 +237,7 @@ class HeteroGraph : public BaseHeteroGraph {
* \note The graph will be pinned inplace. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* kDLROCM: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
void PinMemory_() override;
......
......@@ -61,11 +61,11 @@ TypeArray GetNodeTypesFromMetapath(
}
template
TypeArray GetNodeTypesFromMetapath<kDLGPU, int32_t>(
TypeArray GetNodeTypesFromMetapath<kDLROCM, int32_t>(
const HeteroGraphPtr hg,
const TypeArray metapath);
template
TypeArray GetNodeTypesFromMetapath<kDLGPU, int64_t>(
TypeArray GetNodeTypesFromMetapath<kDLROCM, int64_t>(
const HeteroGraphPtr hg,
const TypeArray metapath);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment