"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "bdaccc8270545aaabb8f31a05ad8afec39882704"
Unverified Commit 8ecbfa57 authored by Ilia Taraban's avatar Ilia Taraban Committed by GitHub
Browse files

[Fix] restore SpMMSumCsrNaive function for float and double (#5615)

parent c5e8481c
...@@ -43,7 +43,41 @@ using AccType = typename std::conditional< ...@@ -43,7 +43,41 @@ using AccType = typename std::conditional<
* for the computation of different nodes. * for the computation of different nodes.
*/ */
template <typename IdType, typename DType, typename Op> template <typename IdType, typename DType, typename Op>
void SpMMSumCsrNaive( typename std::enable_if<!std::is_same<DType, BFloat16>::value, void>::type
SpMMSumCsrNaive(
const BcastOff& bcast, const CSRMatrix& csr, const DType* X, const DType* W,
DType* O) {
const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>();
const IdType* edges = csr.data.Ptr<IdType>();
int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
for (auto rid = b; rid < e; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off =
Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off =
Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
out_off[k] += Op::Call(lhs_off, rhs_off);
}
}
}
});
}
// Naive implementation with additional accumulator, which prevents accuracy
// degradation in less precise data types, like bfloat16.
template <typename IdType, typename DType, typename Op>
typename std::enable_if<std::is_same<DType, BFloat16>::value, void>::type
SpMMSumCsrNaive(
const BcastOff& bcast, const CSRMatrix& csr, const DType* X, const DType* W, const BcastOff& bcast, const CSRMatrix& csr, const DType* X, const DType* W,
DType* O) { DType* O) {
const bool has_idx = !IsNullArray(csr.data); const bool has_idx = !IsNullArray(csr.data);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment