OpenDAS / torch-sparse

Commit 2eba313c authored May 11, 2020 by rusty1s

fixed spspmm for cpu

parent 57852a66
Showing 1 changed file with 24 additions and 48 deletions

csrc/cpu/spspmm_cpu.cpp  +24  -48
csrc/cpu/spspmm_cpu.cpp

@@ -48,80 +48,56 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
   auto rowptrB_data = rowptrB.data_ptr<int64_t>();
   auto colB_data = colB.data_ptr<int64_t>();

-  // Pass 1: Compute CSR row pointer.
   auto rowptrC = torch::empty_like(rowptrA);
   auto rowptrC_data = rowptrC.data_ptr<int64_t>();
   rowptrC_data[0] = 0;
-
-  std::vector<int64_t> mask(K, -1);
-  int64_t nnz = 0, row_nnz, rowA_start, rowA_end, rowB_start, rowB_end, cA, cB;
-  for (auto n = 0; n < rowptrA.numel() - 1; n++) {
-    row_nnz = 0;
-    for (auto eA = rowptrA_data[n]; eA < rowptrA_data[n + 1]; eA++) {
-      cA = colA_data[eA];
-      for (auto eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1]; eB++) {
-        cB = colB_data[eB];
-        if (mask[cB] != n) {
-          mask[cB] = n;
-          row_nnz++;
-        }
-      }
-    }
-    nnz += row_nnz;
-    rowptrC_data[n + 1] = nnz;
-  }
-
-  // Pass 2: Compute CSR entries.
-  auto colC = torch::empty(nnz, rowptrC.options());
-  auto colC_data = colC.data_ptr<int64_t>();
+  torch::Tensor colC;

   torch::optional<torch::Tensor> optional_valueC = torch::nullopt;
-  if (optional_valueA.has_value())
-    optional_valueC = torch::empty(nnz, optional_valueA.value().options());

   AT_DISPATCH_ALL_TYPES(scalar_type, "spspmm", [&] {
-    AT_DISPATCH_HAS_VALUE(optional_valueC, [&] {
-      scalar_t *valA_data = nullptr, *valB_data = nullptr, *valC_data = nullptr;
+    AT_DISPATCH_HAS_VALUE(optional_valueA, [&] {
+      scalar_t *valA_data = nullptr, *valB_data = nullptr;
       if (HAS_VALUE) {
         valA_data = optional_valueA.value().data_ptr<scalar_t>();
         valB_data = optional_valueB.value().data_ptr<scalar_t>();
-        valC_data = optional_valueC.value().data_ptr<scalar_t>();
       }

-      scalar_t valA;
-      rowA_start = 0, nnz = 0;
-      std::vector<scalar_t> vals(K, 0);
-      for (auto n = 1; n < rowptrA.numel(); n++) {
-        rowA_end = rowptrA_data[n];
-        for (auto eA = rowA_start; eA < rowA_end; eA++) {
+      int64_t nnz = 0, cA, cB;
+      std::vector<scalar_t> tmp_vals(K, 0);
+      std::vector<int64_t> cols;
+      std::vector<scalar_t> vals;
+
+      for (auto rA = 0; rA < rowptrA.numel() - 1; rA++) {
+        for (auto eA = rowptrA_data[rA]; eA < rowptrA_data[rA + 1]; eA++) {
           cA = colA_data[eA];
-          if (HAS_VALUE)
-            valA = valA_data[eA];
-          rowB_start = rowptrB_data[cA], rowB_end = rowptrB_data[cA + 1];
-          for (auto eB = rowB_start; eB < rowB_end; eB++) {
+          for (auto eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1]; eB++) {
             cB = colB_data[eB];
             if (HAS_VALUE)
-              vals[cB] += valA * valB_data[eB];
+              tmp_vals[cB] += valA_data[eA] * valB_data[eB];
             else
-              vals[cB] += 1;
+              tmp_vals[cB]++;
           }
         }

         for (auto k = 0; k < K; k++) {
-          if (vals[k] != 0) {
-            colC_data[nnz] = k;
+          if (tmp_vals[k] != 0) {
+            cols.push_back(k);
             if (HAS_VALUE)
-              valC_data[nnz] = vals[k];
+              vals.push_back(tmp_vals[k]);
             nnz++;
           }
-          vals[k] = (scalar_t)0;
+          tmp_vals[k] = (scalar_t)0;
         }
-        rowA_start = rowA_end;
+        rowptrC_data[rA + 1] = nnz;
       }
+
+      colC = torch::from_blob(cols.data(), {nnz}, colA.options()).clone();
+      if (HAS_VALUE) {
+        optional_valueC = torch::from_blob(vals.data(), {nnz},
+                                           optional_valueA.value().options());
+        optional_valueC = optional_valueC.value().clone();
+      }
     });
   });
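For orientation, below is a minimal, self-contained sketch (not the library code) of the row-wise CSR x CSR multiplication pattern that the new (+) side of this diff follows: for each row of A, products are scattered into a dense accumulator of length K (the number of columns of B), and the non-zero slots are then compressed into the output row before the accumulator is cleared. The CSR struct, the spspmm_rowwise name, and the float value type are illustrative assumptions only; the real kernel also handles the value-less case (counting entries instead of multiplying) and writes the results into torch tensors.

// Hypothetical sketch of row-wise CSR x CSR multiplication with a dense
// accumulator; names and types here are illustrative, not from torch-sparse.
#include <cstdint>
#include <iostream>
#include <vector>

struct CSR {
  std::vector<int64_t> rowptr; // size: rows + 1
  std::vector<int64_t> col;    // size: nnz
  std::vector<float> val;      // size: nnz
};

CSR spspmm_rowwise(const CSR &A, const CSR &B, int64_t K) {
  CSR C;
  C.rowptr.push_back(0);
  std::vector<float> acc(K, 0); // dense accumulator for one output row

  for (size_t rA = 0; rA + 1 < A.rowptr.size(); rA++) {
    // Scatter: every entry of row rA of A selects a row of B to accumulate.
    for (int64_t eA = A.rowptr[rA]; eA < A.rowptr[rA + 1]; eA++) {
      int64_t cA = A.col[eA];
      for (int64_t eB = B.rowptr[cA]; eB < B.rowptr[cA + 1]; eB++)
        acc[B.col[eB]] += A.val[eA] * B.val[eB];
    }
    // Compress: emit the non-zero slots of the accumulator, then reset it.
    for (int64_t k = 0; k < K; k++) {
      if (acc[k] != 0) {
        C.col.push_back(k);
        C.val.push_back(acc[k]);
      }
      acc[k] = 0;
    }
    C.rowptr.push_back(static_cast<int64_t>(C.col.size()));
  }
  return C;
}

int main() {
  // A = [[1, 0], [0, 2]], B = [[0, 3], [4, 0]] in CSR form; C = A * B is
  // [[0, 3], [8, 0]].
  CSR A{{0, 1, 2}, {0, 1}, {1.0f, 2.0f}};
  CSR B{{0, 1, 2}, {1, 0}, {3.0f, 4.0f}};
  CSR C = spspmm_rowwise(A, B, /*K=*/2);
  for (size_t i = 0; i < C.col.size(); i++)
    std::cout << "col " << C.col[i] << " val " << C.val[i] << "\n";
}

Compiled with any C++11 compiler, main() prints the two non-zeros of C for the small example above (column 1 with value 3 in row 0, column 0 with value 8 in row 1).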