Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b25bbe64
Unverified
Commit
b25bbe64
authored
Sep 15, 2020
by
Zihao Ye
Committed by
GitHub
Sep 15, 2020
Browse files
Loop reorder (#2201)
parent
233e8198
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
25 deletions
+24
-25
src/array/cpu/spmm.h
src/array/cpu/spmm.h
+24
-25
No files found.
src/array/cpu/spmm.h
View file @
b25bbe64
...
@@ -44,19 +44,20 @@ void SpMMSumCsr(
...
@@ -44,19 +44,20 @@ void SpMMSumCsr(
#pragma omp parallel for
#pragma omp parallel for
for
(
IdType
rid
=
0
;
rid
<
csr
.
num_rows
;
++
rid
)
{
for
(
IdType
rid
=
0
;
rid
<
csr
.
num_rows
;
++
rid
)
{
const
IdType
row_start
=
indptr
[
rid
],
row_end
=
indptr
[
rid
+
1
];
const
IdType
row_start
=
indptr
[
rid
],
row_end
=
indptr
[
rid
+
1
];
DType
*
out_off
=
O
+
rid
*
dim
;
DType
*
out_off
=
O
+
rid
*
dim
;
for
(
int64_t
k
=
0
;
k
<
dim
;
++
k
)
{
std
::
fill
(
out_off
,
out_off
+
dim
,
0
);
DType
accum
=
0
;
for
(
IdType
j
=
row_start
;
j
<
row_end
;
++
j
)
{
for
(
IdType
j
=
row_start
;
j
<
row_end
;
++
j
)
{
const
IdType
cid
=
indices
[
j
];
const
IdType
cid
=
indices
[
j
];
const
IdType
eid
=
has_idx
?
edges
[
j
]
:
j
;
const
IdType
eid
=
has_idx
?
edges
[
j
]
:
j
;
for
(
int64_t
k
=
0
;
k
<
dim
;
++
k
)
{
const
int64_t
lhs_add
=
bcast
.
use_bcast
?
bcast
.
lhs_offset
[
k
]
:
k
;
const
int64_t
lhs_add
=
bcast
.
use_bcast
?
bcast
.
lhs_offset
[
k
]
:
k
;
const
int64_t
rhs_add
=
bcast
.
use_bcast
?
bcast
.
rhs_offset
[
k
]
:
k
;
const
int64_t
rhs_add
=
bcast
.
use_bcast
?
bcast
.
rhs_offset
[
k
]
:
k
;
const
DType
*
lhs_off
=
Op
::
use_lhs
?
X
+
cid
*
lhs_dim
+
lhs_add
:
nullptr
;
const
DType
*
lhs_off
=
const
DType
*
rhs_off
=
Op
::
use_rhs
?
W
+
eid
*
rhs_dim
+
rhs_add
:
nullptr
;
Op
::
use_lhs
?
X
+
cid
*
lhs_dim
+
lhs_add
:
nullptr
;
accum
+=
Op
::
Call
(
lhs_off
,
rhs_off
);
const
DType
*
rhs_off
=
Op
::
use_rhs
?
W
+
eid
*
rhs_dim
+
rhs_add
:
nullptr
;
out_off
[
k
]
+=
Op
::
Call
(
lhs_off
,
rhs_off
);
}
}
out_off
[
k
]
=
accum
;
}
}
}
}
}
}
...
@@ -153,30 +154,28 @@ void SpMMCmpCsr(
...
@@ -153,30 +154,28 @@ void SpMMCmpCsr(
DType
*
out_off
=
O
+
rid
*
dim
;
DType
*
out_off
=
O
+
rid
*
dim
;
IdType
*
argx_off
=
argX
+
rid
*
dim
;
IdType
*
argx_off
=
argX
+
rid
*
dim
;
IdType
*
argw_off
=
argW
+
rid
*
dim
;
IdType
*
argw_off
=
argW
+
rid
*
dim
;
for
(
int64_t
k
=
0
;
k
<
dim
;
++
k
)
{
std
::
fill
(
out_off
,
out_off
+
dim
,
Cmp
::
zero
);
DType
accum
=
Cmp
::
zero
;
if
(
Op
::
use_lhs
)
IdType
ax
=
0
,
aw
=
0
;
std
::
fill
(
argx_off
,
argx_off
+
dim
,
0
);
if
(
Op
::
use_rhs
)
std
::
fill
(
argw_off
,
argw_off
+
dim
,
0
);
for
(
IdType
j
=
row_start
;
j
<
row_end
;
++
j
)
{
for
(
IdType
j
=
row_start
;
j
<
row_end
;
++
j
)
{
const
IdType
cid
=
indices
[
j
];
const
IdType
cid
=
indices
[
j
];
const
IdType
eid
=
has_idx
?
edges
[
j
]
:
j
;
const
IdType
eid
=
has_idx
?
edges
[
j
]
:
j
;
for
(
int64_t
k
=
0
;
k
<
dim
;
++
k
)
{
const
int64_t
lhs_add
=
bcast
.
use_bcast
?
bcast
.
lhs_offset
[
k
]
:
k
;
const
int64_t
lhs_add
=
bcast
.
use_bcast
?
bcast
.
lhs_offset
[
k
]
:
k
;
const
int64_t
rhs_add
=
bcast
.
use_bcast
?
bcast
.
rhs_offset
[
k
]
:
k
;
const
int64_t
rhs_add
=
bcast
.
use_bcast
?
bcast
.
rhs_offset
[
k
]
:
k
;
const
DType
*
lhs_off
=
Op
::
use_lhs
?
X
+
cid
*
lhs_dim
+
lhs_add
:
nullptr
;
const
DType
*
lhs_off
=
Op
::
use_lhs
?
X
+
cid
*
lhs_dim
+
lhs_add
:
nullptr
;
const
DType
*
rhs_off
=
Op
::
use_rhs
?
W
+
eid
*
rhs_dim
+
rhs_add
:
nullptr
;
const
DType
*
rhs_off
=
Op
::
use_rhs
?
W
+
eid
*
rhs_dim
+
rhs_add
:
nullptr
;
const
DType
val
=
Op
::
Call
(
lhs_off
,
rhs_off
);
const
DType
val
=
Op
::
Call
(
lhs_off
,
rhs_off
);
if
(
Cmp
::
Call
(
accum
,
val
))
{
if
(
Cmp
::
Call
(
out_off
[
k
]
,
val
))
{
accum
=
val
;
out_off
[
k
]
=
val
;
if
(
Op
::
use_lhs
)
if
(
Op
::
use_lhs
)
a
x
=
cid
;
a
rgx_off
[
k
]
=
cid
;
if
(
Op
::
use_rhs
)
if
(
Op
::
use_rhs
)
a
w
=
eid
;
a
rgw_off
[
k
]
=
eid
;
}
}
}
}
out_off
[
k
]
=
accum
;
if
(
Op
::
use_lhs
)
argx_off
[
k
]
=
ax
;
if
(
Op
::
use_rhs
)
argw_off
[
k
]
=
aw
;
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment