OpenDAS / torch-sparse / Commits

Commit d49dcbbd, authored Jan 26, 2020 by rusty1s
diag and matmul fixes
Parent: 1a5fb80c

Showing 11 changed files with 130 additions and 136 deletions.
Changed files:

cpu/diag.cpp             +11  -10
cpu/spmm.cpp              +9  -17
cuda/diag.cpp             +7   -5
cuda/diag_kernel.cu      +11  -11
cuda/spmm.cpp             +5   -4
cuda/spmm_kernel.cu      +16  -23
cuda/spspmm_kernel.cu     +4   -8
torch_sparse/diag.py     +17  -19
torch_sparse/matmul.py   +37  -36
torch_sparse/storage.py   +9   -0
torch_sparse/tensor.py    +4   -3
cpu/diag.cpp

@@ -4,23 +4,24 @@
 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

-at::Tensor non_diag_mask(at::Tensor index, int64_t M, int64_t N, int64_t k) {
-  CHECK_CPU(index);
-  int64_t E = index.size(1);
-  index = index.contiguous();
-  auto index_data = index.DATA_PTR<int64_t>();
+at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
+                         int64_t k) {
+  CHECK_CPU(row);
+  CHECK_CPU(col);
+  int64_t E = row.size(0);
   int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
-  auto mask = at::zeros(E + num_diag, index.options().dtype(at::kBool));
+  auto row_data = row.DATA_PTR<int64_t>();
+  auto col_data = col.DATA_PTR<int64_t>();
+  auto mask = at::zeros(E + num_diag, row.options().dtype(at::kBool));
   auto mask_data = mask.DATA_PTR<bool>();

   int64_t r, c;
   if (k < 0) {
     for (int64_t i = 0; i < E; i++) {
-      r = index_data[i], c = index_data[i + E];
+      r = row_data[i], c = col_data[i];
       if (r + k < 0) {
         mask_data[i] = true;
       } else if (r + k >= N) {

@@ -33,7 +34,7 @@ at::Tensor non_diag_mask(at::Tensor index, int64_t M, int64_t N, int64_t k) {
     }
   } else {
     for (int64_t i = 0; i < E; i++) {
-      r = index_data[i], c = index_data[i + E];
+      r = row_data[i], c = col_data[i];
       if (r + k >= N) {
         mask_data[i + num_diag] = true;
       } else if (r + k > c) {
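Note: both versions size the mask as E + num_diag, where num_diag counts the entries on diagonal k of an M x N matrix. A small Python sketch of that formula (the helper name is ours, for illustration only):

# Entries on diagonal k of an M x N matrix, as computed above:
def num_diag(M, N, k):
    return min(M + k, N) if k < 0 else min(M, N - k)

assert num_diag(3, 3, 0) == 3   # main diagonal
assert num_diag(3, 3, 1) == 2   # entries (0, 1) and (1, 2)
assert num_diag(3, 3, -2) == 1  # entry (2, 0)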
cpu/spmm.cpp

@@ -174,36 +174,28 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return std::make_tuple(out, arg_out);
 }

-at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
-                       at::Tensor grad, std::string reduce) {
-  CHECK_CPU(index);
+at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
+                       at::Tensor mat, at::Tensor grad, std::string reduce) {
+  CHECK_CPU(row);
   CHECK_CPU(rowptr);
+  CHECK_CPU(col);
   CHECK_CPU(mat);
   CHECK_CPU(grad);

-  AT_ASSERTM(index.dim() == 2, "Input mismatch");
-  AT_ASSERTM(index.size(0) == 2, "Input mismatch");
   AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
   AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
   AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
   AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
                  reduce2REDUCE.at(reduce) == MEAN,
              "Reduce operation not supported");

-  index = index.contiguous();
   mat = mat.contiguous();
   grad = grad.contiguous();

   auto M = grad.size(-2);
   auto N = mat.size(-2);
-  auto E = index.size(1);
+  auto E = row.numel();
   auto K = mat.size(-1);
   auto B = mat.numel() / (N * K);

-  auto out = at::zeros(index.size(1), grad.options());
+  auto out = at::zeros(row.numel(), grad.options());

-  auto index_data = index.DATA_PTR<int64_t>();
+  auto row_data = row.DATA_PTR<int64_t>();
   auto rowptr_data = rowptr.DATA_PTR<int64_t>();
+  auto col_data = col.DATA_PTR<int64_t>();
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw", [&] {
     auto mat_data = mat.DATA_PTR<scalar_t>();
     auto grad_data = grad.DATA_PTR<scalar_t>();

@@ -214,7 +206,7 @@ at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
       for (int b = 0; b < B; b++) {
        for (int e = 0; e < E; e++) {
-          row = index_data[e], col = index_data[E + e], val = (scalar_t)0;
+          row = row_data[e], col = col_data[e], val = (scalar_t)0;
           for (int k = 0; k < K; k++) {
             val += mat_data[b * N * K + col * K + k] *
                    grad_data[b * M * K + row * K + k];
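The value gradient this backward pass accumulates is, per edge (r, c), the inner product of grad[r, :] and mat[c, :], summed over batches. A hedged single-batch reference in plain PyTorch (names are illustrative; the C++ loop above is authoritative, and the MEAN reduction additionally rescales by row counts):

import torch

def spmm_val_bw_reference(row, col, mat, grad):
    # One gradient entry per edge (r, c): <grad[r, :], mat[c, :]>.
    return (mat[col] * grad[row]).sum(dim=-1)

row = torch.tensor([0, 0, 1])   # edge rows
col = torch.tensor([0, 2, 1])   # edge columns
mat = torch.randn(3, 4)         # N x K dense input
grad = torch.randn(2, 4)        # M x K upstream gradient
print(spmm_val_bw_reference(row, col, mat, grad).shape)  # torch.Size([3])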
cuda/diag.cpp

@@ -2,12 +2,14 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

-at::Tensor non_diag_mask_cuda(at::Tensor index, int64_t M, int64_t N,
-                              int64_t k);
+at::Tensor non_diag_mask_cuda(at::Tensor row, at::Tensor col, int64_t M,
+                              int64_t N, int64_t k);

-at::Tensor non_diag_mask(at::Tensor index, int64_t M, int64_t N, int64_t k) {
-  CHECK_CUDA(index);
-  return non_diag_mask_cuda(index, M, N, k);
+at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
+                         int64_t k) {
+  CHECK_CUDA(row);
+  CHECK_CUDA(col);
+  return non_diag_mask_cuda(row, col, M, N, k);
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
cuda/diag_kernel.cu

@@ -5,14 +5,15 @@
 #define THREADS 1024

-__global__ void non_diag_mask_kernel(const int64_t *index_data,
-                                     bool *out_data,
+__global__ void non_diag_mask_kernel(const int64_t *row_data,
+                                     const int64_t *col_data, bool *out_data,
                                      int64_t N, int64_t k, int64_t num_diag,
                                      int64_t numel) {
   int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
   if (thread_idx < numel) {
-    int64_t r = index_data[thread_idx], c = index_data[thread_idx + numel];
+    int64_t r = row_data[thread_idx], c = col_data[thread_idx];
     if (k < 0) {
       if (r + k < 0) {

@@ -37,21 +38,20 @@ __global__ void non_diag_mask_kernel(const int64_t *index_data, bool *out_data,
   }
 }

-at::Tensor non_diag_mask_cuda(at::Tensor index, int64_t M, int64_t N,
-                              int64_t k) {
-  int64_t E = index.size(1);
-  index = index.contiguous();
-  auto index_data = index.DATA_PTR<int64_t>();
+at::Tensor non_diag_mask_cuda(at::Tensor row, at::Tensor col, int64_t M,
+                              int64_t N, int64_t k) {
+  int64_t E = row.size(0);
   int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);

-  auto mask = at::zeros(E + num_diag, index.options().dtype(at::kBool));
+  auto row_data = row.DATA_PTR<int64_t>();
+  auto col_data = col.DATA_PTR<int64_t>();
+  auto mask = at::zeros(E + num_diag, row.options().dtype(at::kBool));
   auto mask_data = mask.DATA_PTR<bool>();

   auto stream = at::cuda::getCurrentCUDAStream();
   non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
-      index_data, mask_data, N, k, num_diag, E);
+      row_data, col_data, mask_data, N, k, num_diag, E);

   return mask;
 }
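The launch here uses one thread per non-zero entry, ceil-dividing E by the block size. A quick sanity check of that arithmetic (illustrative only):

THREADS = 1024  # matches the #define above

def grid_size(E):
    # Ceil-divide: enough blocks for one thread per entry.
    return (E + THREADS - 1) // THREADS

assert grid_size(1) == 1
assert grid_size(1024) == 1
assert grid_size(1025) == 2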
cuda/spmm.cpp

@@ -20,13 +20,14 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return spmm_cuda(rowptr, col, value_opt, mat, reduce);
 }

-at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
-                       at::Tensor grad, std::string reduce) {
-  CHECK_CUDA(index);
+at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
+                       at::Tensor mat, at::Tensor grad, std::string reduce) {
+  CHECK_CUDA(row);
   CHECK_CUDA(rowptr);
+  CHECK_CUDA(col);
   CHECK_CUDA(mat);
   CHECK_CUDA(grad);
-  return spmm_val_bw_cuda(index, rowptr, mat, grad, reduce);
+  return spmm_val_bw_cuda(row, rowptr, col, mat, grad, reduce);
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
cuda/spmm_kernel.cu

@@ -210,17 +210,18 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
 template <typename scalar_t, ReductionType REDUCE>
 __global__ void
-spmm_val_bw_kernel(const int64_t *index_data, const int64_t *rowptr_data,
-                   const scalar_t *mat_data, const scalar_t *grad_data,
-                   scalar_t *out_data, int B, int M, int N, int E, int K) {
+spmm_val_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
+                   const int64_t *col_data, const scalar_t *mat_data,
+                   const scalar_t *grad_data, scalar_t *out_data, int B, int M,
+                   int N, int E, int K) {
   int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
   int index_idx = (thread_idx >> 5);     // thread_idx / 32
   int lane_idx = thread_idx & (32 - 1);  // thread_idx % 32

   if (index_idx < E) {
-    int row = __ldg(index_data + index_idx);
-    int col = __ldg(index_data + E + index_idx);
+    int row = __ldg(row_data + index_idx);
+    int col = __ldg(col_data + index_idx);
     scalar_t val = (scalar_t)0;
     for (int b = 0; b < B; b++) {

@@ -246,43 +247,35 @@ spmm_val_bw_kernel(const int64_t *index_data, const int64_t *rowptr_data,
   }
 }

-at::Tensor spmm_val_bw_cuda(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
-                            at::Tensor grad, std::string reduce) {
-  AT_ASSERTM(index.dim() == 2, "Input mismatch");
-  AT_ASSERTM(index.size(0) == 2, "Input mismatch");
+at::Tensor spmm_val_bw_cuda(at::Tensor row, at::Tensor rowptr, at::Tensor col,
+                            at::Tensor mat, at::Tensor grad,
+                            std::string reduce) {
   AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
   AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
   AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
   AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
                  reduce2REDUCE.at(reduce) == MEAN,
              "Reduce operation not supported");

-  index = index.contiguous();
   mat = mat.contiguous();
   grad = grad.contiguous();

   auto M = grad.size(-2);
   auto N = mat.size(-2);
-  auto E = index.size(1);
+  auto E = row.numel();
   auto K = mat.size(-1);
   auto B = mat.numel() / (N * K);
   auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);

-  auto out = at::empty(index.size(1), grad.options());
+  auto out = at::zeros(row.numel(), grad.options());
   auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
-    auto index_data = index.DATA_PTR<int64_t>();
+    auto row_data = row.DATA_PTR<int64_t>();
     auto rowptr_data = rowptr.DATA_PTR<int64_t>();
+    auto col_data = col.DATA_PTR<int64_t>();
     auto mat_data = mat.DATA_PTR<scalar_t>();
     auto grad_data = grad.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();

     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
-      spmm_val_bw_kernel<scalar_t, REDUCE><<<BLOCKS, THREADS, 0, stream>>>(
-          index_data, rowptr_data, mat_data, grad_data, out_data, B, M, N, E,
-          K);
+      spmm_val_bw_kernel<scalar_t, REDUCE><<<BLOCKS, THREADS, 0, stream>>>(
+          row_data, rowptr_data, col_data, mat_data, grad_data, out_data, B, M,
+          N, E, K);
     });
   });
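Unlike the diag kernel, this kernel assigns one 32-thread warp per edge (index_idx = thread_idx >> 5), so the grid is sized for E * 32 threads. A sketch of that geometry (illustrative arithmetic only):

THREADS = 1024

def blocks_for(E):
    # One 32-thread warp per edge, as in the BLOCKS computation above.
    return (E * 32 + THREADS - 1) // THREADS

E = 1000
assert blocks_for(E) * THREADS >= E * 32   # a warp exists for every edge
assert (32 * (E - 1) + 31) >> 5 == E - 1   # last lane of edge E-1 maps back to it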
cuda/spspmm_kernel.cu

@@ -121,14 +121,10 @@ spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
         descr, valueC_data, rowptrC_data, colC_data, info, buffer);
   });

-  auto rowC = at::empty_like(colC);
-  auto rowC_data = rowC.DATA_PTR<int>();
-  cusparseXcsr2coo(handle, rowptrC_data, nnzC, M, rowC_data,
-                   CUSPARSE_INDEX_BASE_ZERO);
   cusparseDestroyCsrgemm2Info(info);

-  auto indexC = at::stack({rowC.toType(at::kLong), colC.toType(at::kLong)}, 0);
-  return std::make_tuple(indexC, rowptrC.toType(at::kLong), valueC);
+  rowptrC = rowptrC.toType(at::kLong);
+  colC = colC.toType(at::kLong);
+  return std::make_tuple(rowptrC, colC, valueC);
 }

 // #define THREADS 1024
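With the cusparseXcsr2coo conversion gone, the CUDA path now returns the product directly in CSR form (rowptrC, colC, valueC) instead of a stacked COO index. If COO rows are needed downstream, they remain recoverable from the rowptr; a hedged sketch:

import torch

rowptr = torch.tensor([0, 2, 2, 5])             # CSR: 3 rows, 5 non-zeros
counts = rowptr[1:] - rowptr[:-1]               # non-zeros per row
row = torch.repeat_interleave(torch.arange(3), counts)
print(row)                                      # tensor([0, 0, 2, 2, 2])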
torch_sparse/diag.py

@@ -9,12 +9,9 @@ except ImportError:

 def remove_diag(src, k=0):
-    index, value = src.coo()
-    row, col = index
+    row, col, value = src.coo()

     inv_mask = row != col if k == 0 else row != (col - k)

-    index = index[:, inv_mask]
+    row, col = row[inv_mask], col[inv_mask]
     if src.has_value():
         value = value[inv_mask]

@@ -32,7 +29,7 @@ def remove_diag(src, k=0):
         colcount = src.storage.colcount.clone()
         colcount[col[mask]] -= 1

-    storage = src.storage.__class__(index, value,
+    storage = src.storage.__class__(row=row, col=col, value=value,
                                     sparse_size=src.sparse_size(),
                                     rowcount=rowcount, colcount=colcount,
                                     is_sorted=True)

@@ -45,26 +42,26 @@ def set_diag(src, values=None, k=0):
     src = src.remove_diag(k=0)
-    index, value = src.coo()
+    row, col, value = src.coo()

-    func = diag_cuda if index.is_cuda else diag_cpu
-    mask = func.non_diag_mask(index, src.size(0), src.size(1), k)
+    func = diag_cuda if row.is_cuda else diag_cpu
+    mask = func.non_diag_mask(row, col, src.size(0), src.size(1), k)
     inv_mask = ~mask

-    new_index = index.new_empty((2, mask.size(0)))
-    new_index[:, mask] = index
-
-    num_diag = mask.numel() - index.size(1)
-    start = -k if k < 0 else 0
-    diag_row = torch.arange(start, start + num_diag, device=src.device)
-    new_index[0, inv_mask] = diag_row
-    diag_col = diag_row.add_(k)
-    new_index[1, inv_mask] = diag_col
+    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
+    diag = torch.arange(start, start + num_diag, device=src.device)
+
+    new_row = row.new_empty(mask.size(0))
+    new_row[mask] = row
+    new_row[inv_mask] = diag
+
+    new_col = col.new_empty(mask.size(0))
+    new_col[mask] = col
+    new_col[inv_mask] = diag.add_(k)

     new_value = None
     if src.has_value():
-        new_value = torch.new_empty((mask.size(0), ) + mask.size()[1:])
+        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
         new_value[mask] = value
         new_value[inv_mask] = values if values is not None else 1

@@ -78,8 +75,9 @@ def set_diag(src, values=None, k=0):
         colcount = src.storage.colcount.clone()
         colcount[start + k:start + num_diag + k] += 1

-    storage = src.storage.__class__(new_index, new_value,
+    storage = src.storage.__class__(row=new_row, col=new_col, value=new_value,
                                     sparse_size=src.sparse_size(),
                                     rowcount=rowcount, colcount=colcount,
                                     is_sorted=True)

     return src.__class__.from_storage(storage)
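End-to-end, set_diag(src, values, k) overwrites diagonal k while keeping every off-diagonal entry, and the tensor.py change below exposes it as a SparseTensor method. A dense model of the intended result (illustrative only, not the library's sparse code path):

import torch

A = torch.tensor([[1., 0., 2.],
                  [0., 3., 0.]])
k = 1
vals = torch.tensor([9., 8.])               # one value per diagonal slot
num_diag = min(A.size(0), A.size(1) - k)    # k >= 0 case of the formula above
idx = torch.arange(num_diag)
out = A.clone()
out[idx, idx + k] = vals
print(out)
# tensor([[1., 9., 2.],
#         [0., 3., 8.]])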
torch_sparse/matmul.py

@@ -20,14 +20,13 @@ def spmm(is_cuda):
 class SPMM(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, index, rowcount, rowptr, colptr, csr2csc, value, mat,
-                reduce):
-        out, arg_out = spmm(mat.is_cuda).spmm(rowptr, index[1], value, mat,
-                                              reduce)
+    def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
+                reduce):
+        out, arg_out = spmm(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)
         ctx.reduce = reduce
-        ctx.save_for_backward(index, rowcount, rowptr, colptr, csr2csc, value,
-                              mat, arg_out)
+        ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
+                              csr2csc, arg_out)

         if reduce == 'min' or reduce == 'max':
             ctx.mark_non_differentiable(arg_out)

@@ -37,27 +36,27 @@ class SPMM(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_out, *args):
-        data = ctx.saved_tensors
-        index, rowcount, rowptr, colptr, csr2csc, value, mat, arg_out = data
+        (row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
+         arg_out) = ctx.saved_tensors

         invalid_arg_mask = arg_out_ind = None
-        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[5]
-                                             or ctx.needs_input_grad[6]):
-            invalid_arg_mask = arg_out == index.size(1)
+        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
+                                             or ctx.needs_input_grad[4]):
+            invalid_arg_mask = arg_out == row.size(0)
             arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)

         grad_value = None
-        if ctx.needs_input_grad[5]:
+        if ctx.needs_input_grad[3]:
             if ctx.reduce in ['sum', 'add']:
                 grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
-                    index, rowptr, mat, grad_out, ctx.reduce)
+                    row, rowptr, col, mat, grad_out, ctx.reduce)
             if ctx.reduce == 'mean':
                 grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
-                    index, rowptr, mat, grad_out, ctx.reduce)
+                    row, rowptr, col, mat, grad_out, ctx.reduce)
             elif ctx.reduce in ['min', 'max']:
-                col = index[1][arg_out_ind.flatten()].view_as(arg_out)
+                col = col[arg_out_ind.flatten()].view_as(arg_out)
                 out = mat.gather(-2, col).mul_(grad_out)
                 out.masked_fill_(invalid_arg_mask, 0)
                 grad_value = scatter_add(out.flatten(), arg_out.flatten(),

@@ -65,16 +64,16 @@ class SPMM(torch.autograd.Function):
             grad_value = grad_value[:-1]

         grad_mat = None
-        if ctx.needs_input_grad[6]:
+        if ctx.needs_input_grad[4]:
             if ctx.reduce in ['sum', 'add']:
                 value = value[csr2csc] if value is not None else value
                 grad_mat, _ = spmm(grad_out.is_cuda).spmm(
-                    colptr, index[0][csr2csc], value, grad_out, 'sum')
+                    colptr, row[csr2csc], value, grad_out, 'sum')
             elif ctx.reduce == 'mean':
-                count = rowcount[index[0]].to(mat.dtype).clamp_(min=1)
+                count = rowcount[row].to(mat.dtype).clamp_(min=1)
                 value = count.pow_(-1) if value is None else value / count
-                row = index[0][csr2csc]
+                row = row[csr2csc]
                 value = value[csr2csc] if value is not None else value
                 grad_mat, _ = spmm(grad_out.is_cuda).spmm(colptr, row, value,
                                                           grad_out, 'sum')

@@ -86,19 +85,20 @@ class SPMM(torch.autograd.Function):
             else:
                 value = grad_out
                 value.masked_fill_(invalid_arg_mask, 0)
-                col = index[1][arg_out_ind.flatten()].view_as(arg_out)
+                col = col[arg_out_ind.flatten()].view_as(arg_out)
                 grad_mat = scatter_add(value, col, dim=-2,
                                        dim_size=mat.size(-2))

-        return None, None, None, None, None, grad_value, grad_mat, None
+        return (None, None, None, grad_value, grad_mat, None, None, None,
+                None)

 class SPSPMM(torch.autograd.Function):
     @staticmethod
     def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
         if rowptrA.is_cuda:
-            indexC, rowptrC, valueC = spspmm_cuda.spspmm(
+            rowptrC, colC, valueC = spspmm_cuda.spspmm(
                 rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K)
         else:
             dtype = None
             if valueA is not None:

@@ -116,21 +116,18 @@ class SPSPMM(torch.autograd.Function):
             C = A @ B
-            valueC = torch.from_numpy(C.data).to(dtype) \
-                if dtype is not None else None
             rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
-            C = C.tocoo()
-            rowC = torch.from_numpy(C.row).to(torch.int64)
-            colC = torch.from_numpy(C.col).to(torch.int64)
-            indexC = torch.stack([rowC, colC], dim=0)
+            colC = torch.from_numpy(C.indices).to(torch.int64)
+            valueC = torch.from_numpy(C.data)
+            valueC = valueC.to(dtype) if dtype is not None else valueC

-        ctx.mark_non_differentiable(indexC, rowptrC)
+        ctx.mark_non_differentiable(rowptrC, colC)

         # We cannot return `NoneType` in torch.autograd :(
         if valueC is None:
-            return indexC, rowptrC
+            return rowptrC, colC
         else:
-            return indexC, rowptrC, valueC
+            return rowptrC, colC, valueC

     @staticmethod
     def backward(ctx, grad_indexC, grad_rowptrC, *args):

@@ -152,7 +149,12 @@ def matmul(src, other, reduce='sum'):
     # Sparse-Dense Matrix Multiplication.
     if torch.is_tensor(other):
         assert reduce in ['sum', 'add', 'mean', 'min', 'max']
-        (index, value), rowptr = src.coo(), src.storage.rowptr
+        rowptr, col, value = src.csr()
+
+        row = None
+        if reduce in ['sum', 'add'] and (src.requires_grad
+                                         or other.requires_grad):
+            row = src.storage.row

         rowcount = None
         if other.requires_grad and reduce in ['mean']:

@@ -162,8 +164,8 @@ def matmul(src, other, reduce='sum'):
         if other.requires_grad and reduce in ['sum', 'add', 'mean']:
             csr2csc, colptr = src.storage.csr2csc, src.storage.colptr

-        return SPMM.apply(index, rowcount, rowptr, colptr, csr2csc, value,
-                          other, reduce)
+        return SPMM.apply(row, rowptr, col, value, other, rowcount, colptr,
+                          csr2csc, reduce)

     # Sparse-Sparse Matrix Multiplication.
     elif isinstance(other, src.__class__):

@@ -171,10 +173,9 @@ def matmul(src, other, reduce='sum'):
         assert src.dim() == 2 and other.dim() == 2
         data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0),
                             src.size(1), other.size(1))
-        data = data if len(data) == 3 else data + (None, )
+        (rowptr, col), value = data[:2], data[2] if len(data) == 3 else None
         sparse_size = torch.Size([src.size(0), other.size(1)])
-        out = src.__class__(data[0], data[2], sparse_size, is_sorted=True)
-        out.storage._rowptr = data[1]
-        return out
+        return src.__class__(rowptr=rowptr, col=col, value=value,
+                             sparse_size=sparse_size, is_sorted=True)

     raise ValueError
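A hedged usage sketch of the rearranged sparse-dense path: value gradients now flow through (row, rowptr, col) instead of a stacked 2 x E index tensor. The constructor keywords below follow the storage diff further down and may not match this commit's public API exactly:

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([0, 2, 1, 0])
value = torch.rand(4, requires_grad=True)
src = SparseTensor(row=row, col=col, value=value,
                   sparse_size=torch.Size([3, 3]))
other = torch.randn(3, 2, requires_grad=True)
out = src.matmul(other, reduce='sum')  # dense [3, 2] result
out.sum().backward()                   # grads reach `value` and `other`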
torch_sparse/storage.py

@@ -78,6 +78,7 @@ class SparseStorage(object):
         assert col is not None
         assert col.dtype == torch.long
         assert col.dim() == 1
+        col = col.contiguous()

         if sparse_size is None:
             M = rowptr.numel() - 1 if row is None else row.max().item() + 1

@@ -89,46 +90,54 @@ class SparseStorage(object):
             assert row.device == col.device
             assert row.dim() == 1
             assert row.numel() == col.numel()
+            row = row.contiguous()

         if rowptr is not None:
             assert rowptr.dtype == torch.long
             assert rowptr.device == col.device
             assert rowptr.dim() == 1
             assert rowptr.numel() - 1 == sparse_size[0]
+            rowptr = rowptr.contiguous()

         if value is not None:
             assert value.device == col.device
             assert value.size(0) == col.size(0)
+            value = value.contiguous()

         if rowcount is not None:
             assert rowcount.dtype == torch.long
             assert rowcount.device == col.device
             assert rowcount.dim() == 1
             assert rowcount.numel() == sparse_size[0]
+            rowcount = rowcount.contiguous()

         if colptr is not None:
             assert colptr.dtype == torch.long
             assert colptr.device == col.device
             assert colptr.dim() == 1
             assert colptr.numel() - 1 == sparse_size[1]
+            colptr = colptr.contiguous()

         if colcount is not None:
             assert colcount.dtype == torch.long
             assert colcount.device == col.device
             assert colcount.dim() == 1
             assert colcount.numel() == sparse_size[1]
+            colcount = colcount.contiguous()

         if csr2csc is not None:
             assert csr2csc.dtype == torch.long
             assert csr2csc.device == col.device
             assert csr2csc.dim() == 1
             assert csr2csc.numel() == col.size(0)
+            csr2csc = csr2csc.contiguous()

         if csc2csr is not None:
             assert csc2csr.dtype == torch.long
             assert csc2csr.device == col.device
             assert csc2csr.dim() == 1
             assert csc2csr.numel() == col.size(0)
+            csc2csr = csc2csr.contiguous()

         self._row = row
         self._rowptr = rowptr
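The nine new .contiguous() calls matter because the C++/CUDA kernels above read raw pointers via DATA_PTR, which assumes densely packed, row-major memory. A short illustration:

import torch

col = torch.arange(10)[::2]     # strided view over a larger buffer
print(col.is_contiguous())      # False
col = col.contiguous()          # dense copy, safe to hand to DATA_PTR
print(col.is_contiguous())      # True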
torch_sparse/tensor.py

@@ -11,7 +11,7 @@ from torch_sparse.select import select
 from torch_sparse.index_select import index_select, index_select_nnz
 from torch_sparse.masked_select import masked_select, masked_select_nnz
 import torch_sparse.reduce
-from torch_sparse.diag import remove_diag
+from torch_sparse.diag import remove_diag, set_diag
 from torch_sparse.matmul import matmul
 from torch_sparse.add import add, add_, add_nnz, add_nnz_

@@ -482,8 +482,9 @@ SparseTensor.sum = torch_sparse.reduce.sum
 SparseTensor.mean = torch_sparse.reduce.mean
 SparseTensor.min = torch_sparse.reduce.min
 SparseTensor.max = torch_sparse.reduce.max
-SparseTensor.remove_diag = remove_diag
-SparseTensor.matmul = matmul
+SparseTensor.remove_diag = remove_diag  # TODO
+SparseTensor.set_diag = set_diag  # TODO
+SparseTensor.matmul = matmul  # TODO
 SparseTensor.add = add
 SparseTensor.add_ = add_
 SparseTensor.add_nnz = add_nnz