Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
519306d3
"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "10dc06c8d982f745e54d8d9daff6b258726b8172"
Commit
519306d3
authored
Dec 17, 2019
by
rusty1s
Browse files
fast as fuck spmm kernel
parent
36d045fd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
25 deletions
+40
-25
cuda/spmm_kernel.cu
cuda/spmm_kernel.cu
+40
-25
No files found.
cuda/spmm_kernel.cu
View file @
519306d3
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "compat.cuh"
#include "compat.cuh"
#define THREADS 32 * 16
#define Y_SIZE 32
#define THREADS 256
// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code: https://github.com/owensgroup/merge-spmm
// Code: https://github.com/owensgroup/merge-spmm
template
<
typename
scalar_t
,
size_t
Y_SIZE
>
template
<
typename
scalar_t
>
__global__
void
__global__
void
spmm_row_kernel
(
const
int64_t
*
rowptr_data
,
const
int64_t
*
col_data
,
spmm_row_kernel
(
const
int64_t
*
rowptr_data
,
const
int64_t
*
col_data
,
const
scalar_t
*
val_data
,
const
scalar_t
*
mat_data
,
const
scalar_t
*
val_data
,
const
scalar_t
*
mat_data
,
scalar_t
*
out_data
,
size_t
N
,
size_t
M
,
size_t
K
)
{
scalar_t
*
out_data
,
size_t
N
,
size_t
K
)
{
// We ignore blockIdx.y here, because threads across blockIdx.y operate on the
// We ignore blockIdx.y here, because threads across blockIdx.y operate on the
// same row.
// same row.
...
@@ -23,7 +25,7 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
...
@@ -23,7 +25,7 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
// Compute the column index of `mat` in which the thread is operating.
// Compute the column index of `mat` in which the thread is operating.
int
mat_col_idx
=
lane_idx
+
(
blockIdx
.
y
<<
5
);
int
mat_col_idx
=
lane_idx
+
(
blockIdx
.
y
<<
5
);
// Compute the output index
given in
row-major order.
// Compute the output index
(
row-major order
)
.
int
out_idx
=
row
*
K
+
lane_idx
+
(
blockIdx
.
y
<<
5
);
int
out_idx
=
row
*
K
+
lane_idx
+
(
blockIdx
.
y
<<
5
);
// Helper arrays for warp communication.
// Helper arrays for warp communication.
...
@@ -35,18 +37,30 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
...
@@ -35,18 +37,30 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
if
(
row
<
N
)
{
if
(
row
<
N
)
{
int
row_start
=
__ldg
(
rowptr_data
+
row
);
int
row_start
=
__ldg
(
rowptr_data
+
row
);
int
row_end
=
__ldg
(
rowptr_data
+
row
+
1
);
int
row_end
=
__ldg
(
rowptr_data
+
row
+
1
);
int
col_idx
=
row_start
+
lane_idx
;
int
mat_row
=
-
1
;
scalar_t
val
=
(
scalar_t
)
0
;
scalar_t
sum
=
(
scalar_t
)
0
;
// Iterate over all col indices in parallel with 32 threads.
for
(
int
c
=
row_start
;
c
<
row_end
;
c
+=
32
)
{
if
(
col_idx
<
row_end
)
{
// Coalesced memory access into `col` and `val`.
mat_row
=
__ldg
(
col_data
+
col_idx
)
*
K
;
val
=
__ldg
(
val_data
+
col_idx
);
}
else
{
mat_row
=
0
;
val
=
(
scalar_t
)
0
;
}
col_idx
+=
32
;
// Iterate over all col indices in parallel.
#pragma unroll
for
(
int
col_idx
=
row_start
+
lane_idx
;
col_idx
<
row_end
;
col_idx
+=
32
)
{
int
mat_row
=
__ldg
(
col_data
+
col_idx
)
*
K
;
int
val
=
__ldg
(
val_data
+
col_idx
);
scalar_t
sum
=
(
scalar_t
)
0
;
for
(
int
i
=
0
;
i
<
32
;
i
+=
Y_SIZE
)
{
for
(
int
i
=
0
;
i
<
32
;
i
+=
Y_SIZE
)
{
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
Y_SIZE
;
j
++
)
{
for
(
int
j
=
0
;
j
<
Y_SIZE
;
j
++
)
{
// Warp communication with *all* threads (mask = 0xffffffff).
// Communication between *all* threads in a warp.
// TODO: Compute real bit mask via `__ballot_sync()`.
mat_row_all
[
j
]
=
__shfl_sync
(
0xffffffff
,
mat_row
,
i
+
j
);
mat_row_all
[
j
]
=
__shfl_sync
(
0xffffffff
,
mat_row
,
i
+
j
);
val_all
[
j
]
=
__shfl_sync
(
0xffffffff
,
val
,
i
+
j
);
val_all
[
j
]
=
__shfl_sync
(
0xffffffff
,
val
,
i
+
j
);
}
}
...
@@ -58,34 +72,35 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
...
@@ -58,34 +72,35 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
}
}
}
}
}
}
if
(
lane_idx
<
leftover
)
{
}
out_data
[
out_idx
]
=
sum
;
if
(
lane_idx
<
leftover
)
{
}
// Coalesced memory access into `out`.
out_data
[
out_idx
]
=
sum
;
}
}
}
}
}
}
at
::
Tensor
spmm_cuda
(
at
::
Tensor
rowptr
,
at
::
Tensor
col
,
at
::
Tensor
val
,
at
::
Tensor
spmm_cuda
(
at
::
Tensor
rowptr
,
at
::
Tensor
col
,
at
::
Tensor
val
,
at
::
Tensor
mat
)
{
at
::
Tensor
mat
)
{
// TODO: Set device
auto
N
=
rowptr
.
numel
()
-
1
;
auto
N
=
rowptr
.
numel
()
-
1
;
auto
M
=
mat
.
size
(
0
);
auto
K
=
mat
.
size
(
1
);
auto
K
=
mat
.
size
(
1
);
auto
out
=
at
::
empty
({
N
,
K
},
mat
.
options
());
auto
out
=
at
::
empty
({
N
,
K
},
mat
.
options
());
auto
rowptr_data
=
rowptr
.
DATA_PTR
<
int64_t
>
();
auto
rowptr_data
=
rowptr
.
DATA_PTR
<
int64_t
>
();
auto
col_data
=
col
.
DATA_PTR
<
int64_t
>
();
auto
col_data
=
col
.
DATA_PTR
<
int64_t
>
();
auto
val_data
=
val
.
DATA_PTR
<
float
>
();
auto
mat_data
=
mat
.
DATA_PTR
<
float
>
();
auto
out_data
=
out
.
DATA_PTR
<
float
>
();
auto
block_dim
=
dim3
(
THREADS
);
auto
block_dim
=
dim3
(
THREADS
);
auto
grid_dim
=
dim3
((
N
+
THREADS
-
1
)
/
THREADS
,
(
K
+
32
-
1
)
/
32
);
auto
grid_dim
=
dim3
((
32
*
N
+
THREADS
-
1
)
/
THREADS
,
(
K
+
31
)
/
32
);
AT_DISPATCH_ALL_TYPES
(
mat
.
scalar_type
(),
"spmm_kernel"
,
[
&
]
{
auto
val_data
=
val
.
DATA_PTR
<
scalar_t
>
();
auto
mat_data
=
mat
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
spmm_row_kernel
<
float
,
32
><<<
grid_dim
,
block_dim
,
0
/*, cuda_stream */
>>>
(
spmm_row_kernel
<
scalar_t
>
rowptr_data
,
col_data
,
val_data
,
mat_data
,
out_data
,
N
,
M
,
K
);
<<<
grid_dim
,
block_dim
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
rowptr_data
,
col_data
,
val_data
,
mat_data
,
out_data
,
N
,
K
);
});
return
out
;
return
out
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment