Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
906c97e4
Commit
906c97e4
authored
Nov 05, 2020
by
rusty1s
Browse files
multi-threading in SPMM (CPU)
parent
25700066
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
8 deletions
+24
-8
csrc/cpu/spmm_cpu.cpp
csrc/cpu/spmm_cpu.cpp
+13
-8
setup.py
setup.py
+11
-0
No files found.
csrc/cpu/spmm_cpu.cpp
View file @
906c97e4
#include "spmm_cpu.h"
#include "spmm_cpu.h"
#include <ATen/Parallel.h>
#include "reducer.h"
#include "reducer.h"
#include "utils.h"
#include "utils.h"
...
@@ -47,19 +49,22 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -47,19 +49,22 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
auto
mat_data
=
mat
.
data_ptr
<
scalar_t
>
();
auto
mat_data
=
mat
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
scalar_t
val
;
std
::
vector
<
scalar_t
>
vals
(
K
);
int64_t
row_start
,
row_end
,
c
;
std
::
vector
<
int64_t
>
args
(
K
);
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
AT_DISPATCH_HAS_VALUE
(
optional_value
,
[
&
]
{
AT_DISPATCH_HAS_VALUE
(
optional_value
,
[
&
]
{
if
(
HAS_VALUE
)
{
if
(
HAS_VALUE
)
{
value_data
=
optional_value
.
value
().
data_ptr
<
scalar_t
>
();
value_data
=
optional_value
.
value
().
data_ptr
<
scalar_t
>
();
}
}
for
(
auto
b
=
0
;
b
<
B
;
b
++
)
{
int64_t
grain_size
=
at
::
internal
::
GRAIN_SIZE
/
(
K
*
(
col
.
numel
()
/
M
));
for
(
auto
m
=
0
;
m
<
M
;
m
++
)
{
at
::
parallel_for
(
0
,
B
*
M
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
scalar_t
val
;
std
::
vector
<
scalar_t
>
vals
(
K
);
int64_t
row_start
,
row_end
,
b
,
m
,
c
;
std
::
vector
<
int64_t
>
args
(
K
);
for
(
auto
i
=
begin
;
i
<
end
;
i
++
)
{
b
=
i
/
M
,
m
=
i
%
M
;
row_start
=
rowptr_data
[
m
],
row_end
=
rowptr_data
[
m
+
1
];
row_start
=
rowptr_data
[
m
],
row_end
=
rowptr_data
[
m
+
1
];
for
(
auto
k
=
0
;
k
<
K
;
k
++
)
for
(
auto
k
=
0
;
k
<
K
;
k
++
)
...
@@ -86,7 +91,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -86,7 +91,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
arg_out_data
+
offset
+
k
,
arg_out_data
+
offset
+
k
,
args
[
k
],
row_end
-
row_start
);
args
[
k
],
row_end
-
row_start
);
}
}
}
}
);
});
});
});
});
});
});
...
...
setup.py
View file @
906c97e4
...
@@ -5,6 +5,7 @@ import glob
...
@@ -5,6 +5,7 @@ import glob
from
setuptools
import
setup
,
find_packages
from
setuptools
import
setup
,
find_packages
import
torch
import
torch
from
torch.__config__
import
parallel_info
from
torch.utils.cpp_extension
import
BuildExtension
from
torch.utils.cpp_extension
import
BuildExtension
from
torch.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
CUDA_HOME
from
torch.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
CUDA_HOME
...
@@ -31,6 +32,16 @@ def get_extensions():
...
@@ -31,6 +32,16 @@ def get_extensions():
extra_compile_args
=
{
'cxx'
:
[]}
extra_compile_args
=
{
'cxx'
:
[]}
extra_link_args
=
[]
extra_link_args
=
[]
info
=
parallel_info
()
if
'parallel backend: OpenMP'
in
info
and
'OpenMP not found'
not
in
info
:
extra_compile_args
[
'cxx'
]
+=
[
'-DAT_PARALLEL_OPENMP'
]
if
sys
.
platform
==
'win32'
:
extra_compile_args
[
'cxx'
]
+=
[
'/openmp'
]
else
:
extra_compile_args
[
'cxx'
]
+=
[
'-fopenmp'
]
else
:
print
(
'Compiling without OpenMP...'
)
if
WITH_CUDA
:
if
WITH_CUDA
:
Extension
=
CUDAExtension
Extension
=
CUDAExtension
define_macros
+=
[(
'WITH_CUDA'
,
None
)]
define_macros
+=
[(
'WITH_CUDA'
,
None
)]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment