Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8ecbfa57
Unverified
Commit
8ecbfa57
authored
Apr 27, 2023
by
Ilia Taraban
Committed by
GitHub
Apr 27, 2023
Browse files
[Fix] restore SpMMSumCsrNaive function for float and double (#5615)
parent
c5e8481c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
1 deletion
+35
-1
src/array/cpu/spmm.h
src/array/cpu/spmm.h
+35
-1
No files found.
src/array/cpu/spmm.h
View file @
8ecbfa57
...
@@ -43,7 +43,41 @@ using AccType = typename std::conditional<
...
@@ -43,7 +43,41 @@ using AccType = typename std::conditional<
* for the computation of different nodes.
* for the computation of different nodes.
*/
*/
template
<
typename
IdType
,
typename
DType
,
typename
Op
>
template
<
typename
IdType
,
typename
DType
,
typename
Op
>
void
SpMMSumCsrNaive
(
typename
std
::
enable_if
<!
std
::
is_same
<
DType
,
BFloat16
>::
value
,
void
>::
type
SpMMSumCsrNaive
(
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
DType
*
X
,
const
DType
*
W
,
DType
*
O
)
{
const
bool
has_idx
=
!
IsNullArray
(
csr
.
data
);
const
IdType
*
indptr
=
csr
.
indptr
.
Ptr
<
IdType
>
();
const
IdType
*
indices
=
csr
.
indices
.
Ptr
<
IdType
>
();
const
IdType
*
edges
=
csr
.
data
.
Ptr
<
IdType
>
();
int64_t
dim
=
bcast
.
out_len
,
lhs_dim
=
bcast
.
lhs_len
,
rhs_dim
=
bcast
.
rhs_len
;
runtime
::
parallel_for
(
0
,
csr
.
num_rows
,
[
&
](
size_t
b
,
size_t
e
)
{
for
(
auto
rid
=
b
;
rid
<
e
;
++
rid
)
{
const
IdType
row_start
=
indptr
[
rid
],
row_end
=
indptr
[
rid
+
1
];
DType
*
out_off
=
O
+
rid
*
dim
;
for
(
IdType
j
=
row_start
;
j
<
row_end
;
++
j
)
{
const
IdType
cid
=
indices
[
j
];
const
IdType
eid
=
has_idx
?
edges
[
j
]
:
j
;
for
(
int64_t
k
=
0
;
k
<
dim
;
++
k
)
{
const
int64_t
lhs_add
=
bcast
.
use_bcast
?
bcast
.
lhs_offset
[
k
]
:
k
;
const
int64_t
rhs_add
=
bcast
.
use_bcast
?
bcast
.
rhs_offset
[
k
]
:
k
;
const
DType
*
lhs_off
=
Op
::
use_lhs
?
X
+
cid
*
lhs_dim
+
lhs_add
:
nullptr
;
const
DType
*
rhs_off
=
Op
::
use_rhs
?
W
+
eid
*
rhs_dim
+
rhs_add
:
nullptr
;
out_off
[
k
]
+=
Op
::
Call
(
lhs_off
,
rhs_off
);
}
}
}
});
}
// Naive implementation with additional accumulator, which prevents accuracy
// degradation in less precise data types, like bfloat16.
template
<
typename
IdType
,
typename
DType
,
typename
Op
>
typename
std
::
enable_if
<
std
::
is_same
<
DType
,
BFloat16
>::
value
,
void
>::
type
SpMMSumCsrNaive
(
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
DType
*
X
,
const
DType
*
W
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
DType
*
X
,
const
DType
*
W
,
DType
*
O
)
{
DType
*
O
)
{
const
bool
has_idx
=
!
IsNullArray
(
csr
.
data
);
const
bool
has_idx
=
!
IsNullArray
(
csr
.
data
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment