OpenDAS / torch-sparse / Commits

Commit 36d045fd, authored Dec 16, 2019 by rusty1s

sparse matrix multiplication kernel

parent 51834e88
Showing 4 changed files with 123 additions and 9 deletions:

  cuda/spmm.cpp            +19  -0
  cuda/spmm_kernel.cu      +91  -0
  setup.py                  +4  -0
  torch_sparse/sparse.py    +9  -9
cuda/spmm.cpp (new file, mode 100644)

#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col, at::Tensor val,
                     at::Tensor mat);

at::Tensor spmm(at::Tensor rowptr, at::Tensor col, at::Tensor val,
                at::Tensor mat) {
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(val);
  CHECK_CUDA(mat);
  return spmm_cuda(rowptr, col, val, mat);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("spmm", &spmm, "Sparse Matrix Multiplication (CUDA)");
}
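Once built, this binding is callable from Python as torch_sparse.spmm_cuda.spmm (the module name is registered in setup.py below). A minimal sketch of the intended call, assuming CSR inputs already on the GPU; the concrete tensors are illustrative, not part of this commit:

import torch
from torch_sparse import spmm_cuda  # extension module defined in setup.py

# CSR form of the 2 x 3 matrix [[1, 0, 2], [0, 3, 0]] (illustrative values).
rowptr = torch.tensor([0, 2, 3], device='cuda')
col = torch.tensor([0, 2, 1], device='cuda')
val = torch.tensor([1., 2., 3.], device='cuda')
mat = torch.randn(3, 8, device='cuda')  # dense right-hand side

out = spmm_cuda.spmm(rowptr, col, val, mat)  # dense (2, 8) result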
cuda/spmm_kernel.cu (new file, mode 100644)

#include <ATen/ATen.h>

#include "compat.cuh"

#define THREADS (32 * 16)

// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code:  https://github.com/owensgroup/merge-spmm
template <typename scalar_t, size_t Y_SIZE>
__global__ void spmm_row_kernel(const int64_t *rowptr_data,
                                const int64_t *col_data,
                                const scalar_t *val_data,
                                const scalar_t *mat_data, scalar_t *out_data,
                                size_t N, size_t M, size_t K) {
  // We ignore blockIdx.y here, because threads across blockIdx.y operate on
  // the same row.
  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  int warp_idx = thread_idx >> 5;        // thread_idx / 32
  int lane_idx = thread_idx & (32 - 1);  // thread_idx % 32
  int row = warp_idx;  // Each warp processes exactly one row.

  // Compute the column index of `mat` in which the thread is operating.
  int mat_col_idx = lane_idx + (blockIdx.y << 5);

  // Compute the output index given in row-major order.
  int out_idx = row * K + mat_col_idx;

  // Helper arrays for warp communication.
  int mat_row_all[Y_SIZE];
  scalar_t val_all[Y_SIZE];

  // Number of valid `mat` columns in this 32-wide column slice.
  int leftover = K - (blockIdx.y << 5);

  if (row < N) {
    int row_start = __ldg(rowptr_data + row);
    int row_end = __ldg(rowptr_data + row + 1);

    scalar_t sum = (scalar_t)0;

    // Iterate over the nonzeros of the row in chunks of 32. All 32 lanes
    // stay inside the loop so that `__shfl_sync` with the full mask below
    // is well-defined; lanes past `row_end` contribute a zero value.
    for (int chunk = row_start; chunk < row_end; chunk += 32) {
      int col_idx = chunk + lane_idx;
      int mat_row = 0;
      scalar_t val = (scalar_t)0;
      if (col_idx < row_end) {
        mat_row = __ldg(col_data + col_idx) * K;
        val = __ldg(val_data + col_idx);
      }

      for (int i = 0; i < 32; i += Y_SIZE) {
#pragma unroll
        for (int j = 0; j < Y_SIZE; j++) {
          // Warp communication with *all* threads (mask = 0xffffffff).
          mat_row_all[j] = __shfl_sync(0xffffffff, mat_row, i + j);
          val_all[j] = __shfl_sync(0xffffffff, val, i + j);
        }
#pragma unroll
        for (int j = 0; j < Y_SIZE; j++) {
          if (lane_idx < leftover) {
            // Coalesced memory access into `mat`.
            sum += val_all[j] * __ldg(mat_data + mat_row_all[j] + mat_col_idx);
          }
        }
      }
    }

    if (lane_idx < leftover) {
      out_data[out_idx] = sum;
    }
  }
}

at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col, at::Tensor val,
                     at::Tensor mat) {
  // TODO: Set device.
  auto N = rowptr.numel() - 1;
  auto M = mat.size(0);
  auto K = mat.size(1);
  auto out = at::empty({N, K}, mat.options());

  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
  auto col_data = col.DATA_PTR<int64_t>();
  auto val_data = val.DATA_PTR<float>();
  auto mat_data = mat.DATA_PTR<float>();
  auto out_data = out.DATA_PTR<float>();

  // One warp (32 threads) per row in x; 32-wide column slices of `mat` in y.
  auto block_dim = dim3(THREADS);
  auto grid_dim = dim3((32 * N + THREADS - 1) / THREADS, (K + 32 - 1) / 32);
  spmm_row_kernel<float, 32><<<grid_dim, block_dim, 0 /*, cuda_stream */>>>(
      rowptr_data, col_data, val_data, mat_data, out_data, N, M, K);

  return out;
}
setup.py

@@ -30,6 +30,10 @@ if CUDA_HOME is not None and GPU:
     extra_link_args = ['-lcusparse', '-l', 'cusparse']

     ext_modules += [
+        CUDAExtension('torch_sparse.spmm_cuda',
+                      ['cuda/spmm.cpp', 'cuda/spmm_kernel.cu'],
+                      extra_link_args=extra_link_args,
+                      extra_compile_args=extra_compile_args),
         CUDAExtension('torch_sparse.spspmm_cuda',
                       ['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
                       extra_link_args=extra_link_args,
...
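With CUDA_HOME set and GPU enabled, rebuilding (for example via python setup.py install) produces the new module, which can be smoke-tested against a dense equivalent. A hedged sketch, assuming float32 inputs since spmm_cuda above hard-codes DATA_PTR<float>:

import torch
from torch_sparse import spmm_cuda

rowptr = torch.tensor([0, 2, 3], device='cuda')
col = torch.tensor([0, 2, 1], device='cuda')
val = torch.tensor([1., 2., 3.], device='cuda')
mat = torch.randn(3, 8, device='cuda')

# Dense equivalent of the CSR inputs above.
dense = torch.tensor([[1., 0., 2.], [0., 3., 0.]], device='cuda')
assert torch.allclose(spmm_cuda.spmm(rowptr, col, val, mat), dense @ mat)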
torch_sparse/sparse.py

@@ -59,11 +59,11 @@ class SparseTensor(object):
         return self._index, self._value

     def csr(self):
-        return self._col, self._rowptr, self._value
+        return self._rowptr, self._col, self._value

     def csc(self):
         perm = self._arg_csr_to_csc
-        return self._row[perm], self._colptr, self._value[perm]
+        return self._colptr, self._row[perm], self._value[perm]

     def is_quadratic(self):
         return self.sparse_size[0] == self.sparse_size[1]
@@ -103,24 +103,26 @@ class SparseTensor(object):
         return self.__class__.from_storage(storage)

     def matmul(self, mat2):
-        pass
+        raise NotImplementedError

     def coalesce(self, reduce='add'):
-        pass
+        raise NotImplementedError

     def is_coalesced(self):
-        pass
+        raise NotImplementedError

     def add(self, layout=None):
         # sub, mul, div
         # can take scalars, tensors and other sparse matrices
         # inplace variants can only take scalars or tensors
-        pass
+        raise NotImplementedError

+    # TODO: Slicing, (sum|max|min|prod|...), standard operators, masking, perm
+
     def to_dense(self, dtype=None):
         dtype = dtype or self.dtype
         mat = torch.zeros(self.size(), dtype=dtype, device=self.device)
-        mat[self._row, self._col] = self._value or 1
+        mat[self._row, self._col] = self._value if self.has_value else 1
         return mat

     def to_scipy(self):
@@ -129,8 +131,6 @@ class SparseTensor(object):
     def to_torch_sparse_coo_tensor(self):
         raise NotImplementedError

-    # TODO: Slicing, (sum|max|min|prod|...), standard operators, masing, perm
-
     def __repr__(self):
         i = ' ' * 6
         index, value = self.coo()
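The accessor changes above make csr() and csc() return pointer-first triples and let to_dense() cope with value-less tensors. A sketch of the resulting call pattern, assuming sp is an already-constructed SparseTensor (construction is outside this diff):

rowptr, col, value = sp.csr()  # pointer first, then indices, then values
colptr, row, value = sp.csc()  # likewise for the CSC view
dense = sp.to_dense()          # fills 1s where the tensor carries no values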