OpenDAS / torch-sparse / Commits

Commit 0fd716cb, authored Jan 23, 2020 by rusty1s
"cuda spmm kernel"
parent bd49e20a

Showing 5 changed files with 166 additions and 181 deletions:

  cpu/spmm.cpp          +10   -9
  cuda/spmm.cpp          +9   -25
  cuda/spmm_kernel.cu  +141   -140
  setup.py               +1   -1
  test/test_matmul.py    +5   -6
cpu/spmm.cpp  (view file @ 0fd716cb)

@@ -88,20 +88,21 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
 std::tuple<at::Tensor, at::optional<at::Tensor>>
 spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
      at::Tensor mat, std::string reduce) {
   CHECK_CPU(rowptr);
   CHECK_CPU(col);
   if (value_opt.has_value())
     CHECK_CPU(value_opt.value());
   CHECK_CPU(mat);
-  mat = mat.contiguous();
   AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
   AT_ASSERTM(col.dim() == 1, "Input mismatch");
   if (value_opt.has_value())
     AT_ASSERTM(value_opt.value().dim() == 1);
   AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
+  mat = mat.contiguous();
   auto sizes = mat.sizes().vec();
   sizes[mat.dim() - 2] = rowptr.numel() - 1;
   auto out = at::empty(sizes, mat.options());

@@ -116,10 +117,10 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   auto rowptr_data = rowptr.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
-  auto N = rowptr.numel() - 1;
-  auto M = mat.size(-2);
+  auto M = rowptr.numel() - 1;
+  auto N = mat.size(-2);
   auto K = mat.size(-1);
-  auto B = mat.numel() / (M * K);
+  auto B = mat.numel() / (N * K);
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
     scalar_t *value_data = nullptr;

@@ -138,13 +139,13 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
     }
     for (int b = 0; b < B; b++) {
-      for (int n = 0; n < N; n++) {
-        row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
+      for (int m = 0; m < M; m++) {
+        row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
         for (int k = 0; k < K; k++)
           vals[k] = Reducer<scalar_t, REDUCE>::init();
-        int offset = b * M * K;
+        int offset = b * N * K;
         for (int e = row_start; e < row_end; e++) {
           c = col_data[e];
           if (HAS_VAL)

@@ -159,7 +160,7 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
                 &vals[k], mat_data[offset + c * K + k], &args[k], e);
           }
         }
-        offset = b * N * K + n * K;
+        offset = b * M * K + m * K;
         for (int k = 0; k < K; k++)
           Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
                                            arg_out_data + offset + k,
...
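Note: to make the renamed dimensions above easier to follow (M = number of sparse rows = rowptr.numel() - 1, N = number of rows of the dense `mat`, K = number of dense columns, B = number of batches), here is a minimal dense reference of the batched CSR SpMM with "sum" reduction. This is an illustrative sketch, not code from the commit; only the index arithmetic mirrors the loop above.

// Illustrative sketch only: dense reference for batched CSR SpMM ("sum").
// M = rowptr.size() - 1 sparse rows, N = rows of `mat`, K = columns,
// B = number of batches. `out` has shape [B, M, K], `mat` has [B, N, K].
#include <cstdint>
#include <vector>

void spmm_sum_reference(const std::vector<int64_t> &rowptr,
                        const std::vector<int64_t> &col,
                        const std::vector<float> &value,  // may be all ones
                        const std::vector<float> &mat,    // [B, N, K]
                        std::vector<float> &out,          // [B, M, K]
                        int B, int M, int N, int K) {
  for (int b = 0; b < B; b++) {
    for (int m = 0; m < M; m++) {
      for (int k = 0; k < K; k++) {
        float acc = 0;
        for (int64_t e = rowptr[m]; e < rowptr[m + 1]; e++) {
          int64_t c = col[e];  // column of the sparse entry = row of `mat`
          acc += value[e] * mat[b * N * K + c * K + k];
        }
        out[b * M * K + m * K + k] = acc;
      }
    }
  }
}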
cuda/spmm.cpp  (view file @ 0fd716cb)

@@ -2,37 +2,21 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
 
-at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col,
-                     at::optional<at::Tensor> val, at::Tensor mat,
-                     std::string reduce);
-
-std::tuple<at::Tensor, at::Tensor>
-spmm_arg_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
-              at::Tensor mat, std::string reduce);
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
+          at::Tensor mat, std::string reduce);
 
-at::Tensor spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
-                at::Tensor mat, std::string reduce) {
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
+     at::Tensor mat, std::string reduce) {
   CHECK_CUDA(rowptr);
   CHECK_CUDA(col);
-  if (val.has_value())
-    CHECK_CUDA(val.value());
+  if (value_opt.has_value())
+    CHECK_CUDA(value_opt.value());
   CHECK_CUDA(mat);
-  return spmm_cuda(rowptr, col, val, mat, reduce);
-}
-
-std::tuple<at::Tensor, at::Tensor>
-spmm_arg(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
-         at::Tensor mat, std::string reduce) {
-  CHECK_CUDA(rowptr);
-  CHECK_CUDA(col);
-  if (val.has_value())
-    CHECK_CUDA(val.value());
-  CHECK_CUDA(mat);
-  return spmm_arg_cuda(rowptr, col, val, mat, reduce);
+  return spmm_cuda(rowptr, col, value_opt, mat, reduce);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("spmm", &spmm, "Sparse Matrix Multiplication (CUDA)");
-  m.def("spmm_arg", &spmm_arg, "Sparse Matrix Multiplication With Arg (CUDA)");
 }
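Note: a minimal host-side sketch (not part of the commit) of consuming the new return type std::tuple<at::Tensor, at::optional<at::Tensor>>. The wrapper name `spmm_sum_only` is purely illustrative; the `spmm` declaration matches the one bound above. For "sum"/"mean" the optional arg output stays empty; only "min"/"max" populate it (see cuda/spmm_kernel.cu below).

// Illustrative sketch only: unpacking the (out, optional arg_out) result.
#include <torch/extension.h>
#include <tuple>

std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
     at::Tensor mat, std::string reduce);  // declared as in cuda/spmm.cpp

at::Tensor spmm_sum_only(at::Tensor rowptr, at::Tensor col, at::Tensor mat) {
  // With "sum" there is no argmin/argmax, so the optional stays empty.
  auto result = spmm(rowptr, col, at::nullopt, mat, "sum");
  at::Tensor out = std::get<0>(result);
  at::optional<at::Tensor> arg_out = std::get<1>(result);
  TORCH_CHECK(!arg_out.has_value(), "no arg output expected for 'sum'");
  return out;
}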
cuda/spmm_kernel.cu  (view file @ 0fd716cb)

@@ -4,68 +4,127 @@
 #include "compat.cuh"
 
 #define THREADS 256
+#define FULL_MASK 0xffffffff
+
+enum ReductionType { SUM, MEAN, MIN, MAX };
+
+const std::map<std::string, ReductionType> reduce2REDUCE = {
+    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
+};
+
+#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
+  [&] {                                                                        \
+    switch (reduce2REDUCE.at(reduce)) {                                        \
+    case SUM: {                                                                \
+      const ReductionType REDUCE = SUM;                                        \
+      return __VA_ARGS__();                                                    \
+    }                                                                          \
+    case MEAN: {                                                               \
+      const ReductionType REDUCE = MEAN;                                       \
+      return __VA_ARGS__();                                                    \
+    }                                                                          \
+    case MIN: {                                                                \
+      const ReductionType REDUCE = MIN;                                        \
+      return __VA_ARGS__();                                                    \
+    }                                                                          \
+    case MAX: {                                                                \
+      const ReductionType REDUCE = MAX;                                        \
+      return __VA_ARGS__();                                                    \
+    }                                                                          \
+    }                                                                          \
+  }()
+
+template <typename scalar_t, ReductionType REDUCE> struct Reducer {
+  static inline __host__ __device__ scalar_t init() {
+    if (REDUCE == MIN) {
+      return std::numeric_limits<scalar_t>::max();
+    } else if (REDUCE == MAX) {
+      return std::numeric_limits<scalar_t>::lowest();
+    } else {
+      return (scalar_t)0;
+    }
+  }
 
-#define ADD 0
-#define MEAN 1
-#define MIN 2
-#define MAX 3
+  static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
+                                                int64_t *arg, int64_t new_arg) {
+    if (REDUCE == SUM || REDUCE == MEAN) {
+      *val = *val + new_val;
+    } else if ((REDUCE == MIN && new_val < *val) ||
+               (REDUCE == MAX && new_val > *val)) {
+      *val = new_val;
+      *arg = new_arg;
+    }
+  }
+
+  static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
+                                               int64_t *arg_address, int64_t arg,
+                                               int count) {
+    if (REDUCE == SUM) {
+      *address = val;
+    } else if (REDUCE == MEAN) {
+      *address = val / (scalar_t)max(count, 1);
+    } else if (REDUCE == MIN || REDUCE == MAX) {
+      if (count > 0) {
+        *address = val;
+        *arg_address = arg;
+      } else {
+        *address = (scalar_t)0;
+      }
+    }
+  }
+};
 
 // Paper: Design Principles for Sparse Matrix Multiplication on the GPU
 // Code: https://github.com/owensgroup/merge-spmm
-template <typename scalar_t, int64_t REDUCE, bool HAS_VAL>
+template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
 __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
-                            const scalar_t *val_data, const scalar_t *mat_data,
-                            scalar_t *out_data, int64_t *arg_out_data, size_t N,
-                            size_t K) {
+                            const scalar_t *value_data,
+                            const scalar_t *mat_data, scalar_t *out_data,
+                            int64_t *arg_out_data, int B, int M, int N, int K) {
 
   // We ignore blockIdx.y here, because threads
   // across `blockIdx.y` are treated equally.
   int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  int row = thread_idx >> 5;            // thread_id / 32
-  int lane_idx = thread_idx & (32 - 1); // thread_id % 32
+  int row = thread_idx >> 5;            // thread_idx / 32
+  int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
+  int batch_idx = row / M;
 
   // Compute the column index of `mat` in which the thread is operating.
   int mat_col_idx = lane_idx + (blockIdx.y << 5);
 
   // Compute the output index (row-major order).
-  int out_idx = row * K + lane_idx + (blockIdx.y << 5);
+  int out_idx = row * K + mat_col_idx;
 
   // Helper arrays for warp communication.
-  int mat_rows[32];
-  scalar_t vals[32];
+  int mat_row, mat_rows[32];
+  scalar_t val, vals[HAS_VAL ? 32 : 1];
+  int bla, blas[32];
 
   // Do not aggregate/write across the Y-axis (lane_idx < leftover).
   int leftover = K - (blockIdx.y << 5);
 
-  if (row < N) {
-    int row_start = __ldg(rowptr_data + row);
-    int row_end = __ldg(rowptr_data + row + 1);
+  if (row < B * M) {
+    int row_start = __ldg(rowptr_data + (row % M));
+    int row_end = __ldg(rowptr_data + (row % M) + 1);
     int col_idx = row_start + lane_idx;
 
-    int mat_row;
-    scalar_t val, result;
-    int64_t arg_result = -1;
-
-    // Dependent on `reduce`, we need to initialize `result` accordingly.
-    if (REDUCE == ADD)
-      result = (scalar_t)0;
-    else if (REDUCE == MEAN)
-      result = (scalar_t)0;
-    else if (REDUCE == MIN)
-      result = std::numeric_limits<scalar_t>::max();
-    else if (REDUCE == MAX)
-      result = std::numeric_limits<scalar_t>::min();
-
-    // Iterate over all col indices in parallel within a warp.
+    scalar_t result = Reducer<scalar_t, REDUCE>::init();
+    int64_t arg;
+
+    // Iterate over all `col` indices in parallel within a warp.
     for (int c = row_start; c < row_end; c += 32) {
       if (col_idx < row_end) {
         // Coalesced memory access into `col` and `val`.
         mat_row = __ldg(col_data + col_idx) * K;
-        val = HAS_VAL ? __ldg(val_data + col_idx) : (scalar_t)1;
+        bla = col_idx;
+        if (HAS_VAL)
+          val = __ldg(value_data + col_idx);
       } else {
-        mat_row = 0;
-        val = (scalar_t)0;
+        mat_row = -1;
+        bla = -1;
+        if (HAS_VAL)
+          val = (scalar_t)0;
       }
       col_idx += 32;
...
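Note: the AT_DISPATCH_REDUCTION_TYPES macro added above turns the runtime `reduce` string into a compile-time constant, so Reducer<scalar_t, REDUCE> and spmm_kernel<scalar_t, REDUCE, HAS_VAL> are specialized per reduction instead of branching per element. A standalone, illustrative sketch of that pattern (the macro and function names here are renamed for illustration and are not part of the commit):

// Illustrative sketch only: runtime string -> compile-time ReductionType.
#include <iostream>
#include <map>
#include <string>

enum ReductionType { SUM, MEAN, MIN, MAX };

const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};

#define DISPATCH_REDUCTION_TYPES(reduce, ...)                                  \
  [&] {                                                                        \
    switch (reduce2REDUCE.at(reduce)) {                                        \
    case SUM: {                                                                \
      const ReductionType REDUCE = SUM;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MEAN: {                                                               \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MIN: {                                                                \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    default: {                                                                 \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    }                                                                          \
  }()

// REDUCE is usable as a template argument because it is a const enum
// initialized with a constant expression inside each switch case.
template <ReductionType REDUCE> int identity() {
  return REDUCE == MIN ? 1000000 : (REDUCE == MAX ? -1000000 : 0);
}

int main() {
  std::string reduce = "max";
  int init = DISPATCH_REDUCTION_TYPES(reduce, [&] { return identity<REDUCE>(); });
  std::cout << init << "\n";  // prints -1000000
}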
@@ -73,141 +132,83 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
...
@@ -73,141 +132,83 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
32
;
i
++
)
{
for
(
int
i
=
0
;
i
<
32
;
i
++
)
{
// Communication between all threads in a warp.
// Communication between all threads in a warp.
mat_rows
[
i
]
=
__shfl_sync
(
0xffffffff
,
mat_row
,
i
);
mat_rows
[
i
]
=
__shfl_sync
(
FULL_MASK
,
mat_row
,
i
);
vals
[
i
]
=
__shfl_sync
(
0xffffffff
,
val
,
i
);
blas
[
i
]
=
__shfl_sync
(
FULL_MASK
,
bla
,
i
);
if
(
HAS_VAL
)
vals
[
i
]
=
__shfl_sync
(
FULL_MASK
,
val
,
i
);
}
}
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
32
;
i
++
)
{
for
(
int
i
=
0
;
i
<
32
;
i
++
)
{
if
(
lane_idx
<
leftover
&&
val
s
[
i
]
!=
0
)
{
if
(
lane_idx
<
leftover
&&
mat_row
s
[
i
]
!=
-
1
)
{
// Coalesced memory access into `mat`.
// Coalesced memory access into `mat`.
val
=
vals
[
i
]
*
__ldg
(
mat_data
+
mat_rows
[
i
]
+
mat_col_idx
);
val
=
__ldg
(
mat_data
+
batch_idx
*
N
*
K
+
mat_rows
[
i
]
+
mat_col_idx
);
if
(
HAS_VAL
)
// Aggregate results along row.
val
=
vals
[
i
]
*
val
;
if
(
REDUCE
==
ADD
)
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
result
,
val
,
&
arg
,
c
+
i
);
result
+=
val
;
else
if
(
REDUCE
==
MEAN
)
result
+=
val
;
else
if
(
REDUCE
==
MIN
)
{
if
(
val
<
result
)
{
result
=
val
;
arg_result
=
row_start
+
i
;
}
}
else
if
(
REDUCE
==
MAX
)
{
if
(
val
>
result
)
{
result
=
val
;
arg_result
=
row_start
+
i
;
}
}
}
}
}
}
}
}
if
(
lane_idx
<
leftover
)
{
if
(
lane_idx
<
leftover
)
{
// Coalesced write into `out` (dependent on `reduce`).
// Coalesced write into `out`.
if
(
REDUCE
==
ADD
)
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
out_idx
,
result
,
out_data
[
out_idx
]
=
result
;
arg_out_data
+
out_idx
,
arg
,
else
if
(
REDUCE
==
MEAN
)
row_end
-
row_start
);
out_data
[
out_idx
]
=
result
/
scalar_t
(
row_end
-
row_start
);
else
if
(
REDUCE
==
MIN
)
{
arg_out_data
[
out_idx
]
=
arg_result
;
if
(
result
==
std
::
numeric_limits
<
scalar_t
>::
max
())
out_data
[
out_idx
]
=
(
scalar_t
)
0
;
else
out_data
[
out_idx
]
=
result
;
}
else
if
(
REDUCE
==
MAX
)
{
arg_out_data
[
out_idx
]
=
arg_result
;
if
(
result
==
std
::
numeric_limits
<
scalar_t
>::
min
())
out_data
[
out_idx
]
=
(
scalar_t
)
0
;
else
out_data
[
out_idx
]
=
result
;
}
}
}
}
}
}
}
at
::
Tensor
spmm_cuda
(
at
::
Tensor
rowptr
,
at
::
Tensor
col
,
std
::
tuple
<
at
::
Tensor
,
at
::
optional
<
at
::
Tensor
>>
at
::
optional
<
at
::
Tensor
>
val
,
at
::
Tensor
mat
,
spmm_cuda
(
at
::
Tensor
rowptr
,
at
::
Tensor
col
,
at
::
optional
<
at
::
Tensor
>
value_opt
,
std
::
string
reduce
)
{
at
::
Tensor
mat
,
std
::
string
reduce
)
{
auto
N
=
rowptr
.
size
(
0
)
-
1
;
auto
K
=
mat
.
size
(
1
);
auto
out
=
at
::
empty
({
N
,
K
},
mat
.
options
());
auto
rowptr_data
=
rowptr
.
DATA_PTR
<
int64_t
>
();
AT_ASSERTM
(
rowptr
.
dim
()
==
1
,
"Input mismatch"
);
auto
col_data
=
col
.
DATA_PTR
<
int64_t
>
();
AT_ASSERTM
(
col
.
dim
()
==
1
,
"Input mismatch"
);
if
(
value_opt
.
has_value
())
AT_ASSERTM
(
value_opt
.
value
().
dim
()
==
1
);
AT_ASSERTM
(
mat
.
dim
()
>=
2
,
"Input mismatch"
);
auto
block
=
dim3
(
THREADS
);
mat
=
mat
.
contiguous
();
auto
grid
=
dim3
((
32
*
N
+
THREADS
-
1
)
/
THREADS
,
(
K
+
31
)
/
32
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
sizes
=
mat
.
sizes
().
vec
();
AT_DISPATCH_ALL_TYPES
(
mat
.
scalar_type
(),
"spmm_kernel"
,
[
&
]
{
sizes
[
mat
.
dim
()
-
2
]
=
rowptr
.
numel
()
-
1
;
auto
mat_data
=
mat
.
DATA_PTR
<
scalar_t
>
();
auto
out
=
at
::
empty
(
sizes
,
mat
.
options
());
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
if
(
val
.
has_value
())
{
at
::
optional
<
at
::
Tensor
>
arg_out
=
at
::
nullopt
;
auto
val_data
=
val
.
value
().
DATA_PTR
<
scalar_t
>
();
int64_t
*
arg_out_data
=
nullptr
;
if
(
reduce
==
"add"
)
if
(
reduce2REDUCE
.
at
(
reduce
)
==
MIN
||
reduce2REDUCE
.
at
(
reduce
)
==
MAX
)
{
spmm_kernel
<
scalar_t
,
ADD
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
arg_out
=
at
::
full_like
(
out
,
col
.
numel
(),
rowptr
.
options
());
rowptr_data
,
col_data
,
val_data
,
mat_data
,
out_data
,
nullptr
,
N
,
K
);
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
else
if
(
reduce
==
"mean"
)
spmm_kernel
<
scalar_t
,
MEAN
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
rowptr_data
,
col_data
,
val_data
,
mat_data
,
out_data
,
nullptr
,
N
,
K
);
}
else
{
if
(
reduce
==
"add"
)
spmm_kernel
<
scalar_t
,
ADD
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
rowptr_data
,
col_data
,
nullptr
,
mat_data
,
out_data
,
nullptr
,
N
,
K
);
else
if
(
reduce
==
"mean"
)
spmm_kernel
<
scalar_t
,
MEAN
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
rowptr_data
,
col_data
,
nullptr
,
mat_data
,
out_data
,
nullptr
,
N
,
K
);
}
}
});
return
out
;
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
spmm_arg_cuda
(
at
::
Tensor
rowptr
,
at
::
Tensor
col
,
at
::
optional
<
at
::
Tensor
>
val
,
at
::
Tensor
mat
,
std
::
string
reduce
)
{
auto
N
=
rowptr
.
size
(
0
)
-
1
;
auto
K
=
mat
.
size
(
1
);
auto
out
=
at
::
empty
({
N
,
K
},
mat
.
options
());
auto
arg_out
=
at
::
empty
({
N
,
K
},
rowptr
.
options
());
auto
rowptr_data
=
rowptr
.
DATA_PTR
<
int64_t
>
();
auto
rowptr_data
=
rowptr
.
DATA_PTR
<
int64_t
>
();
auto
col_data
=
col
.
DATA_PTR
<
int64_t
>
();
auto
col_data
=
col
.
DATA_PTR
<
int64_t
>
();
auto
arg_out_data
=
arg_out
.
DATA_PTR
<
int64_t
>
();
auto
block
=
dim3
(
THREADS
);
auto
M
=
rowptr
.
numel
()
-
1
;
auto
grid
=
dim3
((
32
*
N
+
THREADS
-
1
)
/
THREADS
,
(
K
+
31
)
/
32
);
auto
N
=
mat
.
size
(
-
2
);
auto
K
=
mat
.
size
(
-
1
);
auto
B
=
mat
.
numel
()
/
(
N
*
K
);
auto
BLOCKS
=
dim3
((
32
*
B
*
M
+
THREADS
-
1
)
/
THREADS
,
(
K
+
31
)
/
32
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_ALL_TYPES
(
mat
.
scalar_type
(),
"spmm_kernel"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES
(
mat
.
scalar_type
(),
"spmm_kernel"
,
[
&
]
{
auto
mat_data
=
mat
.
DATA_PTR
<
scalar_t
>
();
auto
mat_data
=
mat
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
if
(
val
.
has_value
())
{
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
auto
val_data
=
val
.
value
().
DATA_PTR
<
scalar_t
>
();
if
(
value_opt
.
has_value
())
{
if
(
reduce
==
"min"
)
auto
value_data
=
value_opt
.
value
().
DATA_PTR
<
scalar_t
>
();
spmm_kernel
<
scalar_t
,
MIN
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
spmm_kernel
<
scalar_t
,
REDUCE
,
true
><<<
BLOCKS
,
THREADS
,
0
,
stream
>>>
(
rowptr_data
,
col_data
,
val_data
,
mat_data
,
out_data
,
arg_out_data
,
rowptr_data
,
col_data
,
value_data
,
mat_data
,
out_data
,
arg_out_data
,
N
,
K
);
B
,
M
,
N
,
K
);
else
if
(
reduce
==
"max"
)
spmm_kernel
<
scalar_t
,
MAX
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
rowptr_data
,
col_data
,
val_data
,
mat_data
,
out_data
,
arg_out_data
,
N
,
K
);
}
else
{
}
else
{
if
(
reduce
==
"min"
)
spmm_kernel
<
scalar_t
,
REDUCE
,
false
><<<
BLOCKS
,
THREADS
,
0
,
stream
>>>
(
spmm_kernel
<
scalar_t
,
MIN
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
rowptr_data
,
col_data
,
nullptr
,
mat_data
,
out_data
,
arg_out_data
,
B
,
rowptr_data
,
col_data
,
nullptr
,
mat_data
,
out_data
,
arg_out_data
,
N
,
M
,
N
,
K
);
K
);
else
if
(
reduce
==
"max"
)
spmm_kernel
<
scalar_t
,
MAX
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
rowptr_data
,
col_data
,
nullptr
,
mat_data
,
out_data
,
arg_out_data
,
N
,
K
);
}
}
});
});
});
return
std
::
make_tuple
(
out
,
arg_out
);
return
std
::
make_tuple
(
out
,
arg_out
);
}
}
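Note: a short, illustrative host-only sketch (not part of the commit) of the thread-to-work mapping the kernel uses: one warp of 32 threads handles one of the B * M output rows, and blockIdx.y tiles the K columns in chunks of 32; the grid size mirrors the `BLOCKS` expression above. The helper names here are made up for illustration.

// Illustrative sketch only: spmm_kernel's launch geometry and index math.
#include <cstdio>

constexpr int kThreads = 256;  // mirrors `#define THREADS 256`

struct LaunchGeometry {
  int grid_x;  // number of thread blocks along x
  int grid_y;  // number of 32-wide column tiles
};

LaunchGeometry spmm_geometry(int B, int M, int K) {
  return {(32 * B * M + kThreads - 1) / kThreads, (K + 31) / 32};
}

void describe_thread(int thread_idx, int block_idx_y, int M, int K) {
  int row = thread_idx >> 5;             // which output row (0 .. B*M-1)
  int lane_idx = thread_idx & (32 - 1);  // position within the warp
  int batch_idx = row / M;               // which batch the row belongs to
  int sparse_row = row % M;              // row of the CSR matrix (rowptr index)
  int mat_col_idx = lane_idx + (block_idx_y << 5);  // column of `mat`
  int out_idx = row * K + mat_col_idx;              // row-major output index
  std::printf("row=%d batch=%d sparse_row=%d col=%d out=%d\n",
              row, batch_idx, sparse_row, mat_col_idx, out_idx);
}

int main() {
  LaunchGeometry g = spmm_geometry(/*B=*/2, /*M=*/100, /*K=*/70);
  std::printf("grid = (%d, %d), block = %d\n", g.grid_x, g.grid_y, kThreads);
  describe_thread(/*thread_idx=*/37, /*block_idx_y=*/1, /*M=*/100, /*K=*/70);
}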
setup.py  (view file @ 0fd716cb)

@@ -8,7 +8,7 @@ import torch
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
 
 cxx_extra_compile_args = []
-nvcc_extra_compile_args = []
+nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr']
 
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
...
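Note on the new nvcc flags: `-arch=sm_35` targets compute capability 3.5 (Kepler) as the minimum architecture, and `--expt-relaxed-constexpr` lets device code call host constexpr functions. The new Reducer calls std::numeric_limits<scalar_t>::max() and ::lowest() from __host__ __device__ code, which is exactly the case this flag covers. A minimal, illustrative kernel (not from the commit) that needs the flag:

// Illustrative sketch only: fails to compile under nvcc without
// --expt-relaxed-constexpr, because numeric_limits members are host constexpr.
#include <limits>

template <typename scalar_t>
__global__ void fill_with_lowest(scalar_t *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = std::numeric_limits<scalar_t>::lowest();  // needs the flag
}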
test/test_matmul.py  (view file @ 0fd716cb)

@@ -9,19 +9,18 @@ import torch_scatter
 from .utils import devices, grad_dtypes
 
-devices = ['cpu']
+devices = ['cpu', 'cuda']
 grad_dtypes = [torch.float]
 
 reductions = ['sum', 'mean', 'min', 'max']
-reductions = ['min']
+reductions = ['min', 'max']
 
 
 @pytest.mark.parametrize('dtype,device,reduce',
                          product(grad_dtypes, devices, reductions))
 def test_spmm(dtype, device, reduce):
     src = torch.randn((10, 8), dtype=dtype, device=device)
-    src[2, :] = 0  # Delete one row...
-    src[:, 2:4] = 0  # Delete one col...
+    src[2:4, :] = 0  # Remove multiple rows.
+    src[:, 2:4] = 0  # Remove multiple columns.
 
     src = SparseTensor.from_dense(src).requires_grad_()
     (row, col), value = src.coo()
...

@@ -35,7 +34,7 @@ def test_spmm(dtype, device, reduce):
     if reduce == 'min':
         expected[expected > 1000] = 0
     if reduce == 'max':
-        expected[expected < 1000] = 0
+        expected[expected < -1000] = 0
 
     grad_out = torch.randn_like(expected)
...