OpenDAS / torch-sparse / Commits

Commit ac5d7a78, authored Dec 18, 2019 by rusty1s

    all kernels

Parent: 8b07240b

Showing 2 changed files with 173 additions and 47 deletions:

    cuda/spmm.cpp        (+25, -6)
    cuda/spmm_kernel.cu  (+148, -41)
cuda/spmm.cpp

@@ -2,18 +2,37 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
 
-at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col, at::Tensor val,
-                     at::Tensor mat);
+at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col,
+                     at::optional<at::Tensor> val, at::Tensor mat,
+                     std::string reduce);
+
+std::tuple<at::Tensor, at::Tensor>
+spmm_arg_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
+              at::Tensor mat, std::string reduce);
 
-at::Tensor spmm(at::Tensor rowptr, at::Tensor col, at::Tensor val,
-                at::Tensor mat) {
+at::Tensor spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
+                at::Tensor mat, std::string reduce) {
   CHECK_CUDA(rowptr);
   CHECK_CUDA(col);
-  CHECK_CUDA(val);
+  if (val.has_value())
+    CHECK_CUDA(val.value());
   CHECK_CUDA(mat);
-  return spmm_cuda(rowptr, col, val, mat);
+  return spmm_cuda(rowptr, col, val, mat, reduce);
+}
+
+std::tuple<at::Tensor, at::Tensor>
+spmm_arg(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
+         at::Tensor mat, std::string reduce) {
+  CHECK_CUDA(rowptr);
+  CHECK_CUDA(col);
+  if (val.has_value())
+    CHECK_CUDA(val.value());
+  CHECK_CUDA(mat);
+  return spmm_arg_cuda(rowptr, col, val, mat, reduce);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("spmm", &spmm, "Sparse Matrix Multiplication (CUDA)");
+  m.def("spmm_arg", &spmm_arg, "Sparse Matrix Multiplication With Arg (CUDA)");
 }
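A note on semantics (editorial sketch, not part of the commit): `reduce` selects how the per-row products are aggregated, a missing `val` is treated as all-ones, and for "min"/"max" the index of the winning nonzero is recorded alongside the value. The plain-C++ reference below mirrors that contract on the CPU; the name spmm_cpu_reference and its flat row-major layout are illustrative assumptions.

// CPU reference for the contract of spmm/spmm_arg (illustrative sketch).
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

// out[n*K + k] = reduce over nonzeros e of row n of val[e] * mat[col[e]*K + k].
// arg_out mirrors the kernel's arg_out_data for "min"/"max"; -1 marks "no entry".
void spmm_cpu_reference(const std::vector<int64_t> &rowptr,
                        const std::vector<int64_t> &col,
                        const std::vector<float> &val, // empty => all-ones
                        const std::vector<float> &mat, int64_t K,
                        const std::string &reduce, std::vector<float> &out,
                        std::vector<int64_t> &arg_out) {
  int64_t N = (int64_t)rowptr.size() - 1;
  out.assign(N * K, 0.f);
  arg_out.assign(N * K, -1);
  for (int64_t n = 0; n < N; n++) {
    for (int64_t k = 0; k < K; k++) {
      float result = 0.f; // matches the kernel's ADD/MEAN initialization
      if (reduce == "min") result = std::numeric_limits<float>::max();
      if (reduce == "max") result = std::numeric_limits<float>::lowest();
      int64_t arg = -1;
      for (int64_t e = rowptr[n]; e < rowptr[n + 1]; e++) {
        float v = (val.empty() ? 1.f : val[e]) * mat[col[e] * K + k];
        if (reduce == "add" || reduce == "mean") result += v;
        else if (reduce == "min" && v < result) { result = v; arg = e; }
        else if (reduce == "max" && v > result) { result = v; arg = e; }
      }
      int64_t deg = rowptr[n + 1] - rowptr[n];
      if (reduce == "mean" && deg > 0) result /= (float)deg;
      // Empty rows reduce to 0 under "min"/"max", as in the kernel's write-out.
      if ((reduce == "min" || reduce == "max") && deg == 0) result = 0.f;
      out[n * K + k] = result;
      arg_out[n * K + k] = arg;
    }
  }
}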
cuda/spmm_kernel.cu
@@ -3,24 +3,27 @@
 #include "compat.cuh"
 
-#define Y_SIZE 32
 #define THREADS 256
+#define ADD 0
+#define MEAN 1
+#define MIN 2
+#define MAX 3
 
 // Paper: Design Principles for Sparse Matrix Multiplication on the GPU
 // Code: https://github.com/owensgroup/merge-spmm
-template <typename scalar_t>
+template <typename scalar_t, int64_t REDUCE, bool HAS_VAL>
 __global__ void
-spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
-                const scalar_t *val_data, const scalar_t *mat_data,
-                scalar_t *out_data, size_t N, size_t K) {
-  // We ignore blockIdx.y here, because threads across blockIdx.y operate on
-  // the same row.
+spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
+            const scalar_t *val_data, const scalar_t *mat_data,
+            scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K) {
+  // We ignore blockIdx.y here, because threads across `blockIdx.y` are
+  // treated equally.
   int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  int warp_idx = thread_idx >> 5;        // thread_id / 32
+  int row = thread_idx >> 5;             // thread_id / 32
   int lane_idx = thread_idx & (32 - 1);  // thread_id % 32
-  int row = warp_idx;  // Each warp processes exactly one row.
 
   // Compute the column index of `mat` in which the thread is operating.
   int mat_col_idx = lane_idx + (blockIdx.y << 5);
@@ -29,9 +32,10 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
   int out_idx = row * K + lane_idx + (blockIdx.y << 5);
 
   // Helper arrays for warp communication.
-  int mat_row_all[Y_SIZE];
-  scalar_t val_all[Y_SIZE];
+  int mat_rows[32];
+  scalar_t vals[32];
 
+  // Do not aggregate/write across the Y-axis (lane_idx < leftover).
   int leftover = K - (blockIdx.y << 5);
 
   if (row < N) {
@@ -39,17 +43,27 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
     int row_end = __ldg(rowptr_data + row + 1);
     int col_idx = row_start + lane_idx;
 
-    int mat_row = -1;
-    scalar_t val = (scalar_t)0;
-    scalar_t sum = (scalar_t)0;
+    int mat_row;
+    scalar_t val, result;
+    int64_t arg_result = -1;
 
-    // Iterate over all col indices in parallel with 32 threads.
+    // Dependent on `reduce`, we need to initialize `result` accordingly.
+    if (REDUCE == ADD)
+      result = (scalar_t)0;
+    else if (REDUCE == MEAN)
+      result = (scalar_t)0;
+    else if (REDUCE == MIN)
+      result = std::numeric_limits<scalar_t>::max();
+    else if (REDUCE == MAX)
+      result = std::numeric_limits<scalar_t>::min();
+
+    // Iterate over all col indices in parallel within a warp.
     for (int c = row_start; c < row_end; c += 32) {
       if (col_idx < row_end) {
         // Coalesced memory access into `col` and `val`.
         mat_row = __ldg(col_data + col_idx) * K;
-        val = __ldg(val_data + col_idx);
+        val = HAS_VAL ? __ldg(val_data + col_idx) : (scalar_t)1;
       } else {
         mat_row = 0;
         val = (scalar_t)0;
@@ -57,50 +71,143 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
       col_idx += 32;
 
 #pragma unroll
-      for (int i = 0; i < 32; i += Y_SIZE) {
-        // Communication between all threads in a warp.
-#pragma unroll
-        for (int j = 0; j < Y_SIZE; j++) {
-          mat_row_all[j] = __shfl_sync(0xffffffff, mat_row, i + j);
-          val_all[j] = __shfl_sync(0xffffffff, val, i + j);
-        }
+      for (int i = 0; i < 32; i++) {
+        // Communication between *all* threads in a warp.
+        mat_rows[i] = __shfl_sync(0xffffffff, mat_row, i);
+        vals[i] = __shfl_sync(0xffffffff, val, i);
+      }
 
 #pragma unroll
-        for (int j = 0; j < Y_SIZE; j++) {
-          if (lane_idx < leftover) {
-            // Coalesced memory access into `mat`.
-            sum += val_all[j] * __ldg(mat_data + mat_row_all[j] + mat_col_idx);
-          }
-        }
-      }
+      for (int i = 0; i < 32; i++) {
+        if (lane_idx < leftover && vals[i] != 0) {
+          // Coalesced memory access into `mat`.
+          val = vals[i] * __ldg(mat_data + mat_rows[i] + mat_col_idx);
+
+          // Aggregate results along row.
+          if (REDUCE == ADD)
+            result += val;
+          else if (REDUCE == MEAN)
+            result += val;
+          else if (REDUCE == MIN) {
+            if (val < result) {
+              result = val;
+              arg_result = row_start + i;
+            }
+          } else if (REDUCE == MAX) {
+            if (val > result) {
+              result = val;
+              arg_result = row_start + i;
+            }
+          }
+        }
+      }
     }
 
     if (lane_idx < leftover) {
-      // Coalesced memory access into `out`.
-      out_data[out_idx] = sum;
+      // Coalesced write into `out` (dependent on `reduce`).
+      if (REDUCE == ADD)
+        out_data[out_idx] = result;
+      else if (REDUCE == MEAN)
+        out_data[out_idx] = result / scalar_t(row_end - row_start);
+      else if (REDUCE == MIN) {
+        arg_out_data[out_idx] = arg_result;
+        if (result == std::numeric_limits<scalar_t>::max())
+          out_data[out_idx] = (scalar_t)0;
+        else
+          out_data[out_idx] = result;
+      } else if (REDUCE == MAX) {
+        arg_out_data[out_idx] = arg_result;
+        if (result == std::numeric_limits<scalar_t>::min())
+          out_data[out_idx] = (scalar_t)0;
+        else
+          out_data[out_idx] = result;
+      }
     }
   }
 }
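The launch geometry used below, dim3((32 * N + THREADS - 1) / THREADS, (K + 31) / 32), assigns one warp per CSR row along x and one 32-column slab of `mat` per blockIdx.y. As a quick sanity check, the kernel's index arithmetic can be replayed on the host; the sketch below is editorial (sizes and the chosen example thread are made up, not from the commit).

// Host-side replay of the kernel's thread-to-work mapping (illustrative).
#include <cstdio>

int main() {
  const int THREADS = 256; // matches #define THREADS in the kernel
  long N = 1000, K = 70;   // example matrix sizes (assumed)

  long grid_x = (32 * N + THREADS - 1) / THREADS; // enough warps for N rows
  long grid_y = (K + 31) / 32;                    // 32-wide column slabs of mat

  // Reproduce the per-thread indexing for one example thread.
  int block_idx_x = 3, block_idx_y = 1, thread_in_block = 77;
  int thread_idx = THREADS * block_idx_x + thread_in_block;
  int row = thread_idx >> 5;             // each warp owns one row
  int lane_idx = thread_idx & (32 - 1);  // position within the warp
  int mat_col_idx = lane_idx + (block_idx_y << 5); // column of mat handled
  int leftover = K - (block_idx_y << 5); // valid lanes in the last slab

  printf("grid=(%ld,%ld) row=%d lane=%d col=%d active=%d\n", grid_x, grid_y,
         row, lane_idx, mat_col_idx, lane_idx < leftover);
  return 0;
}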
...
-at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col, at::Tensor val,
-                     at::Tensor mat) {
-  auto N = rowptr.numel() - 1;
+at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col,
+                     at::optional<at::Tensor> val, at::Tensor mat,
+                     std::string reduce) {
+  auto N = rowptr.size(0) - 1;
   auto K = mat.size(1);
   auto out = at::empty({N, K}, mat.options());
 
   auto rowptr_data = rowptr.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
 
-  auto block_dim = dim3(THREADS);
-  auto grid_dim = dim3((32 * N + THREADS - 1) / THREADS, (K + 31) / 32);
+  auto block = dim3(THREADS);
+  auto grid = dim3((32 * N + THREADS - 1) / THREADS, (K + 31) / 32);
+  auto stream = at::cuda::getCurrentCUDAStream();
 
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
-    auto val_data = val.DATA_PTR<scalar_t>();
     auto mat_data = mat.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
 
-    spmm_row_kernel<scalar_t>
-        <<<grid_dim, block_dim, 0, at::cuda::getCurrentCUDAStream()>>>(
-            rowptr_data, col_data, val_data, mat_data, out_data, N, K);
+    if (val.has_value()) {
+      auto val_data = val.value().DATA_PTR<scalar_t>();
+      if (reduce == "add")
+        spmm_kernel<scalar_t, ADD, true><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, val_data, mat_data, out_data, nullptr, N, K);
+      else if (reduce == "mean")
+        spmm_kernel<scalar_t, MEAN, true><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, val_data, mat_data, out_data, nullptr, N, K);
+    } else {
+      if (reduce == "add")
+        spmm_kernel<scalar_t, ADD, false><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, nullptr, mat_data, out_data, nullptr, N, K);
+      else if (reduce == "mean")
+        spmm_kernel<scalar_t, MEAN, false><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, nullptr, mat_data, out_data, nullptr, N, K);
+    }
   });
 
   return out;
 }
+
+std::tuple<at::Tensor, at::Tensor>
+spmm_arg_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
+              at::Tensor mat, std::string reduce) {
+  auto N = rowptr.size(0) - 1;
+  auto K = mat.size(1);
+  auto out = at::empty({N, K}, mat.options());
+  auto arg_out = at::empty({N, K}, rowptr.options());
+
+  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
+  auto col_data = col.DATA_PTR<int64_t>();
+  auto arg_out_data = arg_out.DATA_PTR<int64_t>();
+
+  auto block = dim3(THREADS);
+  auto grid = dim3((32 * N + THREADS - 1) / THREADS, (K + 31) / 32);
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
+    auto mat_data = mat.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    if (val.has_value()) {
+      auto val_data = val.value().DATA_PTR<scalar_t>();
+      if (reduce == "min")
+        spmm_kernel<scalar_t, MIN, true><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, val_data, mat_data, out_data, arg_out_data,
+            N, K);
+      else if (reduce == "max")
+        spmm_kernel<scalar_t, MAX, true><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, val_data, mat_data, out_data, arg_out_data,
+            N, K);
+    } else {
+      if (reduce == "min")
+        spmm_kernel<scalar_t, MIN, false><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data,
+            N, K);
+      else if (reduce == "max")
+        spmm_kernel<scalar_t, MAX, false><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data,
+            N, K);
+    }
+  });
+
+  return std::make_tuple(out, arg_out);
+}
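For completeness, here is a hypothetical caller of the two new entry points. This is an editorial sketch, not code from the commit: it assumes a libtorch C++ build where the functions declared in cuda/spmm.cpp are linked in, and uses at::nullopt to exercise the HAS_VAL == false path.

// Hypothetical usage of spmm/spmm_arg on a tiny CSR matrix (sketch).
#include <torch/torch.h>
#include <string>
#include <tuple>

// Declarations as exposed by cuda/spmm.cpp (assumed linked in).
at::Tensor spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
                at::Tensor mat, std::string reduce);
std::tuple<at::Tensor, at::Tensor>
spmm_arg(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
         at::Tensor mat, std::string reduce);

int main() {
  auto lopts = torch::dtype(torch::kLong).device(torch::kCUDA);
  auto fopts = torch::dtype(torch::kFloat).device(torch::kCUDA);

  // 2x3 sparse matrix [[1, 0, 2], [0, 3, 0]] in CSR form.
  auto rowptr = torch::tensor({0, 2, 3}, lopts);
  auto col = torch::tensor({0, 2, 1}, lopts);
  auto val = torch::tensor({1.f, 2.f, 3.f}, fopts);
  auto mat = torch::rand({3, 8}, fopts);

  auto sum = spmm(rowptr, col, val, mat, "add");            // weighted row sums
  auto mean = spmm(rowptr, col, at::nullopt, mat, "mean");  // unweighted mean
  auto [mx, argmx] = spmm_arg(rowptr, col, val, mat, "max"); // values + argmax
  return 0;
}

Note that "add"/"mean" route through spmm (arg_out_data == nullptr), while "min"/"max" route through spmm_arg, which additionally returns the index tensor filled from arg_result.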