OpenDAS / torch-sparse · Commits

Commit 6a7f10e5, authored Jan 23, 2020 by rusty1s

    matmul complete

Parent: 0fd716cb

Showing 6 changed files with 134 additions and 31 deletions (+134 -31)
cpu/spmm.cpp             +28 -18
cuda/spmm.cpp            +14  -0
cuda/spmm_kernel.cu      +84  -8
test/test_matmul.py       +0  -3
torch_sparse/matmul.py    +3  -2
torch_sparse/tensor.py    +5  -0
cpu/spmm.cpp
...
@@ -174,46 +174,56 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return std::make_tuple(out, arg_out);
 }
 
-at::Tensor spmm_val_bw(at::Tensor rowptr, at::Tensor col, at::Tensor mat,
-                       at::Tensor grad, std::string reduce) {
+at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
+                       at::Tensor grad, std::string reduce) {
+  CHECK_CPU(index);
   CHECK_CPU(rowptr);
-  CHECK_CPU(col);
   CHECK_CPU(mat);
   CHECK_CPU(grad);
 
+  AT_ASSERTM(index.dim() == 2, "Input mismatch");
+  AT_ASSERTM(index.size(0) == 2, "Input mismatch");
   AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
   AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
   AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
   AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
                  reduce2REDUCE.at(reduce) == MEAN,
              "Reduce operation not supported");
 
+  index = index.contiguous();
   mat = mat.contiguous();
   grad = grad.contiguous();
 
-  auto M = rowptr.numel() - 1;
+  auto M = grad.size(-2);
   auto N = mat.size(-2);
+  auto E = index.size(1);
   auto K = mat.size(-1);
   auto B = mat.numel() / (N * K);
 
-  auto out = at::zeros(col.sizes(), grad.options());
+  auto out = at::zeros(index.size(1), grad.options());
 
+  auto index_data = index.DATA_PTR<int64_t>();
   auto rowptr_data = rowptr.DATA_PTR<int64_t>();
-  auto col_data = col.DATA_PTR<int64_t>();
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw", [&] {
     auto mat_data = mat.DATA_PTR<scalar_t>();
     auto grad_data = grad.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
 
     scalar_t val;
-    int64_t row_start, row_end, c;
+    int64_t row, col;
     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
       for (int b = 0; b < B; b++) {
-        for (int m = 0; m < M; m++) {
-          row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
-          for (int e = row_start; e < row_end; e++) {
-            c = col_data[e], val = (scalar_t)0;
-            for (int k = 0; k < K; k++) {
-              val += mat_data[b * N * K + c * K + k] *
-                     grad_data[b * M * K + m * K + k];
-            }
-            if (REDUCE == MEAN)
-              val = val / (scalar_t)(row_end - row_start);
-            out_data[e] += val;
-          }
+        for (int e = 0; e < E; e++) {
+          row = index_data[e], col = index_data[E + e], val = (scalar_t)0;
+          for (int k = 0; k < K; k++) {
+            val += mat_data[b * N * K + col * K + k] *
+                   grad_data[b * M * K + row * K + k];
+          }
+          if (REDUCE == MEAN) {
+            int row_start = rowptr_data[row], row_end = rowptr_data[row + 1];
+            val /= (scalar_t)std::max(row_end - row_start, 1);
+          }
+          out_data[e] += val;
         }
       }
     });
...
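In words, the rewritten backward walks the COO `index` once per edge instead of iterating CSR rows: for edge e = (row, col), the value gradient is the inner product of `mat[col]` and `grad[row]` along the feature dimension, divided by the clamped row degree for 'mean'. A dense PyTorch reference for the B == 1 case (a sketch for checking the semantics, not part of the commit):

    import torch

    def spmm_val_bw_reference(index, rowptr, mat, grad, reduce='sum'):
        # index: 2 x E COO indices; rowptr: CSR row pointer of shape (M + 1,)
        # mat: N x K dense input; grad: M x K upstream gradient
        row, col = index[0], index[1]
        # grad_value[e] = sum_k mat[col[e], k] * grad[row[e], k]
        out = (mat[col] * grad[row]).sum(dim=-1)
        if reduce == 'mean':
            # normalize by row degree, clamped to 1 like the C++ std::max
            deg = (rowptr[row + 1] - rowptr[row]).clamp(min=1)
            out = out / deg.to(out.dtype)
        return out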
cuda/spmm.cpp
...
@@ -6,6 +6,9 @@ std::tuple<at::Tensor, at::optional<at::Tensor>>
 spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
           at::Tensor mat, std::string reduce);
 
+at::Tensor spmm_val_bw_cuda(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
+                            at::Tensor grad, std::string reduce);
+
 std::tuple<at::Tensor, at::optional<at::Tensor>>
 spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
     at::Tensor mat, std::string reduce) {
...
@@ -17,6 +20,17 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return spmm_cuda(rowptr, col, value_opt, mat, reduce);
 }
 
+at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
+                       at::Tensor grad, std::string reduce) {
+  CHECK_CUDA(index);
+  CHECK_CUDA(rowptr);
+  CHECK_CUDA(mat);
+  CHECK_CUDA(grad);
+  return spmm_val_bw_cuda(index, rowptr, mat, grad, reduce);
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("spmm", &spmm, "Sparse Matrix Multiplication (CUDA)");
+  m.def("spmm_val_bw", &spmm_val_bw,
+        "Sparse-Dense Matrix Multiplication Value Backward (CPU)");
 }
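These bindings compile into a Python extension module. A generic way to build and call them during development (a sketch using PyTorch's JIT extension loader; the repo's real build goes through setup.py with its own compile flags):

    from torch.utils.cpp_extension import load

    # Build the module from the two sources above (assumes a CUDA toolchain).
    ext = load(name='spmm_cuda', sources=['cuda/spmm.cpp', 'cuda/spmm_kernel.cu'])

    # Per the PYBIND11_MODULE block, this exposes:
    #   ext.spmm(rowptr, col, value_opt, mat, reduce)
    #   ext.spmm_val_bw(index, rowptr, mat, grad, reduce)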
cuda/spmm_kernel.cu
...
@@ -99,12 +99,11 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
   // Helper arrays for warp communication.
   int mat_row, mat_rows[32];
   scalar_t val, vals[HAS_VAL ? 32 : 1];
-  int bla, blas[32];
 
   // Do not aggregate/write across the Y-axis (lane_idx < leftover).
   int leftover = K - (blockIdx.y << 5);
 
-  if (row < B * M) {
+  if (batch_idx < B) {
     int row_start = __ldg(rowptr_data + (row % M));
     int row_end = __ldg(rowptr_data + (row % M) + 1);
     int col_idx = row_start + lane_idx;
...
@@ -118,12 +117,10 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
     if (col_idx < row_end) {
       // Coalesced memory access into `col` and `val`.
       mat_row = __ldg(col_data + col_idx) * K;
-      bla = col_idx;
       if (HAS_VAL)
         val = __ldg(value_data + col_idx);
     } else {
       mat_row = -1;
-      bla = -1;
       if (HAS_VAL)
         val = (scalar_t)0;
     }
...
@@ -133,7 +130,6 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
     for (int i = 0; i < 32; i++) {
       // Communication between all threads in a warp.
       mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
-      blas[i] = __shfl_sync(FULL_MASK, bla, i);
       if (HAS_VAL)
         vals[i] = __shfl_sync(FULL_MASK, val, i);
     }
...
@@ -182,9 +178,6 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
 
-  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
-  auto col_data = col.DATA_PTR<int64_t>();
-
   auto M = rowptr.numel() - 1;
   auto N = mat.size(-2);
   auto K = mat.size(-1);
...
@@ -193,6 +186,8 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
+    auto rowptr_data = rowptr.DATA_PTR<int64_t>();
+    auto col_data = col.DATA_PTR<int64_t>();
     auto mat_data = mat.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
...
@@ -212,3 +207,84 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return std::make_tuple(out, arg_out);
 }
+
+template <typename scalar_t, ReductionType REDUCE>
+__global__ void
+spmm_val_bw_kernel(const int64_t *index_data, const int64_t *rowptr_data,
+                   const scalar_t *mat_data, const scalar_t *grad_data,
+                   scalar_t *out_data, int B, int M, int N, int E, int K) {
+  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+  int index_idx = (thread_idx >> 5);    // thread_idx / 32
+  int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
+
+  if (index_idx < E) {
+    int row = __ldg(index_data + index_idx);
+    int col = __ldg(index_data + E + index_idx);
+
+    scalar_t val = (scalar_t)0;
+    for (int b = 0; b < B; b++) {
+      for (int k = lane_idx; k < K; k += 32) {
+        val += mat_data[b * N * K + col * K + k] *
+               grad_data[b * M * K + row * K + k];
+      }
+    }
+
+#pragma unroll
+    for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
+      val += __shfl_down_sync(FULL_MASK, val, i);
+    }
+
+    if (lane_idx == 0) {
+      if (REDUCE == MEAN) {
+        int row_start = __ldg(rowptr_data + row);
+        int row_end = __ldg(rowptr_data + row + 1);
+        val /= (scalar_t)max(row_end - row_start, 1);
+      }
+      out_data[index_idx] = val;
+    }
+  }
+}
+
+at::Tensor spmm_val_bw_cuda(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
+                            at::Tensor grad, std::string reduce) {
+  AT_ASSERTM(index.dim() == 2, "Input mismatch");
+  AT_ASSERTM(index.size(0) == 2, "Input mismatch");
+  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
+  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
+  AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
+  AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
+                 reduce2REDUCE.at(reduce) == MEAN,
+             "Reduce operation not supported");
+
+  index = index.contiguous();
+  mat = mat.contiguous();
+  grad = grad.contiguous();
+
+  auto M = grad.size(-2);
+  auto N = mat.size(-2);
+  auto E = index.size(1);
+  auto K = mat.size(-1);
+  auto B = mat.numel() / (N * K);
+  auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
+
+  auto out = at::empty(index.size(1), grad.options());
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
+    auto index_data = index.DATA_PTR<int64_t>();
+    auto rowptr_data = rowptr.DATA_PTR<int64_t>();
+    auto mat_data = mat.DATA_PTR<scalar_t>();
+    auto grad_data = grad.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      spmm_val_bw_kernel<scalar_t, REDUCE>
+          <<<BLOCKS, THREADS, 0, stream>>>(index_data, rowptr_data, mat_data,
+                                           grad_data, out_data, B, M, N, E, K);
+    });
+  });
+
+  return out;
+}
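The new kernel assigns one 32-thread warp per nonzero: lanes stride over the K feature columns in steps of 32, then a shuffle-based tree reduction folds the 32 partial sums down to lane 0, which alone writes `out_data[index_idx]`. A plain-Python emulation of that tree reduction (a sketch of the idea, not the CUDA API):

    def warp_reduce(partials):
        # partials: the 32 per-lane partial sums; mirrors the
        # `for (int i = 16; i > 0; i /= 2) val += __shfl_down_sync(...)` loop.
        vals = list(partials)
        offset = 16
        while offset > 0:
            prev = vals[:]  # shuffles exchange values from before this step
            for lane in range(32 - offset):
                vals[lane] = prev[lane] + prev[lane + offset]
            offset //= 2
        return vals[0]  # lane 0 ends up holding the full sum

    assert warp_reduce(range(32)) == sum(range(32))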
test/test_matmul.py
...
@@ -9,10 +9,7 @@ import torch_scatter
 from .utils import devices, grad_dtypes
 
-devices = ['cpu', 'cuda']
-grad_dtypes = [torch.float]
-
 reductions = ['sum', 'mean', 'min', 'max']
-reductions = ['min', 'max']
 
 
 @pytest.mark.parametrize('dtype,device,reduce',
...
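The deleted lines were local overrides that had pinned the suite to a reduced matrix; with them gone, the shared `devices` and `grad_dtypes` from `.utils` drive all four reductions again. The parametrization pattern looks like this (a sketch with stand-in values, not the actual test body):

    from itertools import product

    import pytest
    import torch

    devices = ['cpu']            # stand-in for .utils.devices
    grad_dtypes = [torch.float]  # stand-in for .utils.grad_dtypes
    reductions = ['sum', 'mean', 'min', 'max']

    @pytest.mark.parametrize('dtype,device,reduce',
                             product(grad_dtypes, devices, reductions))
    def test_dense_reference(dtype, device, reduce):
        # Dense references that SpMM results can be checked against.
        adj = (torch.rand(4, 4, dtype=dtype, device=device) > 0.5).to(dtype)
        other = torch.rand(4, 8, dtype=dtype, device=device)
        if reduce == 'sum':
            out = adj @ other
        elif reduce == 'mean':
            out = adj @ other / adj.sum(dim=-1, keepdim=True).clamp(min=1)
        else:  # 'min'/'max' have no dense matmul shortcut
            out = None
        assert out is None or out.size() == (4, 8)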
torch_sparse/matmul.py
...
@@ -44,11 +44,11 @@ class SPMM(torch.autograd.Function):
         if ctx.needs_input_grad[5]:
             if ctx.reduce in ['sum', 'add']:
                 grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
-                    rowptr, index[1], mat, grad_out, ctx.reduce)
+                    index, rowptr, mat, grad_out, ctx.reduce)
 
             if ctx.reduce == 'mean':
                 grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
-                    rowptr, index[1], mat, grad_out, ctx.reduce)
+                    index, rowptr, mat, grad_out, ctx.reduce)
 
             elif ctx.reduce in ['min', 'max']:
                 col = index[1][arg_out_ind.flatten()].view_as(arg_out)
...
@@ -108,5 +108,6 @@ def matmul(src, other, reduce='sum'):
     elif isinstance(other, src.__class__):
+        assert reduce in ['sum', 'add']
         raise NotImplementedError
 
     raise ValueError
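End to end, the `sum`/`add` and `mean` branches above now hand the full COO `index` plus `rowptr` to the five-argument `spmm_val_bw`. A minimal autograd round trip (hypothetical setup; the `SparseTensor` constructor kwargs are an assumption about this early API):

    import torch
    from torch_sparse import SparseTensor

    row = torch.tensor([0, 0, 1, 2])
    col = torch.tensor([1, 2, 0, 1])
    value = torch.rand(4, requires_grad=True)
    A = SparseTensor(row=row, col=col, value=value)  # 3 x 3, inferred from indices

    x = torch.rand(3, 8)
    out = A.matmul(x, 'mean')  # forward SpMM with mean reduction
    out.sum().backward()       # value.grad is produced by spmm_val_bw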
torch_sparse/tensor.py
...
@@ -12,6 +12,7 @@ from torch_sparse.index_select import index_select, index_select_nnz
 from torch_sparse.masked_select import masked_select, masked_select_nnz
 import torch_sparse.reduce
 from torch_sparse.diag import remove_diag
+from torch_sparse.matmul import matmul
 
 
 class SparseTensor(object):
...
@@ -410,6 +411,9 @@ class SparseTensor(object):
         return out
 
+    def __matmul__(a, b):
+        return matmul(a, b, reduce='sum')
+
     # String Reputation ###################################################
 
     def __repr__(self):
...
@@ -446,6 +450,7 @@ SparseTensor.mean = torch_sparse.reduce.mean
 SparseTensor.min = torch_sparse.reduce.min
 SparseTensor.max = torch_sparse.reduce.max
 SparseTensor.remove_diag = remove_diag
+SparseTensor.matmul = matmul
 
 # SparseTensor.add = add
 # SparseTensor.add_nnz = add_nnz
...
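With the import, the `__matmul__` definition, and the bound method in place, both spellings below reach the same `matmul` dispatcher; the `@` operator always sum-reduces. A usage sketch (assuming `SparseTensor.from_dense` exists in this version of the API):

    import torch
    from torch_sparse import SparseTensor

    A = SparseTensor.from_dense(torch.rand(3, 3))
    x = torch.rand(3, 8)

    out_sum = A @ x                 # __matmul__(a, b) -> matmul(a, b, reduce='sum')
    out_mean = A.matmul(x, 'mean')  # the bound method exposes `reduce`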