Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
28cb8de4
Commit
28cb8de4
authored
Feb 02, 2020
by
rusty1s
Browse files
fix rowcount/colcount and added spmm mean
parent
d613c5c0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
140 additions
and
47 deletions
+140
-47
csrc/spmm.cpp
csrc/spmm.cpp
+102
-44
test/test_matmul.py
test/test_matmul.py
+1
-1
torch_sparse/matmul.py
torch_sparse/matmul.py
+35
-0
torch_sparse/storage.py
torch_sparse/storage.py
+2
-2
No files found.
csrc/spmm.cpp
View file @
28cb8de4
...
...
@@ -24,7 +24,7 @@ spmm_fw(torch::Tensor rowptr, torch::Tensor col,
torch
::
Tensor
spmm_value_bw
(
torch
::
Tensor
row
,
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
mat
,
torch
::
Tensor
grad
,
std
::
string
reduce
)
{
if
(
row
ptr
.
device
().
is_cuda
())
{
if
(
row
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
return
spmm_value_bw_cuda
(
row
,
rowptr
,
col
,
mat
,
grad
,
reduce
);
#else
...
...
@@ -42,25 +42,21 @@ using torch::autograd::variable_list;
class
SPMMSum
:
public
torch
::
autograd
::
Function
<
SPMMSum
>
{
public:
static
variable_list
forward
(
AutogradContext
*
ctx
,
torch
::
optional
<
Variable
>
opt
ional
_row
,
torch
::
optional
<
Variable
>
opt_row
,
Variable
rowptr
,
Variable
col
,
Variable
value
,
torch
::
optional
<
Variable
>
opt
ional
_colptr
,
torch
::
optional
<
Variable
>
opt
ional
_csr2csc
,
torch
::
optional
<
Variable
>
opt_colptr
,
torch
::
optional
<
Variable
>
opt_csr2csc
,
Variable
mat
)
{
torch
::
Tensor
row
;
if
(
optional_row
.
has_value
())
row
=
optional_row
.
value
();
torch
::
optional
<
torch
::
Tensor
>
optional_value
=
torch
::
nullopt
;
auto
row
=
opt_row
.
has_value
()
?
opt_row
.
value
()
:
torch
::
Tensor
();
auto
colptr
=
opt_colptr
.
has_value
()
?
opt_colptr
.
value
()
:
torch
::
Tensor
();
auto
csr2csc
=
opt_csr2csc
.
has_value
()
?
opt_csr2csc
.
value
()
:
torch
::
Tensor
();
torch
::
optional
<
torch
::
Tensor
>
opt_value
=
torch
::
nullopt
;
if
(
value
.
numel
()
>
0
)
optional_value
=
value
;
torch
::
Tensor
colptr
;
if
(
optional_colptr
.
has_value
())
colptr
=
optional_colptr
.
value
();
torch
::
Tensor
csr2csc
;
if
(
optional_csr2csc
.
has_value
())
csr2csc
=
optional_csr2csc
.
value
();
auto
out
=
std
::
get
<
0
>
(
spmm_fw
(
rowptr
,
col
,
optional_value
,
mat
,
"sum"
));
opt_value
=
value
;
auto
out
=
std
::
get
<
0
>
(
spmm_fw
(
rowptr
,
col
,
opt_value
,
mat
,
"sum"
));
ctx
->
save_for_backward
({
row
,
rowptr
,
col
,
value
,
colptr
,
csr2csc
,
mat
});
return
{
out
};
}
...
...
@@ -68,30 +64,23 @@ public:
static
variable_list
backward
(
AutogradContext
*
ctx
,
variable_list
grad_outs
)
{
auto
grad_out
=
grad_outs
[
0
];
auto
saved
=
ctx
->
get_saved_variables
();
auto
row
=
saved
[
0
];
auto
rowptr
=
saved
[
1
];
auto
col
=
saved
[
2
];
auto
value
=
saved
[
3
];
torch
::
optional
<
torch
::
Tensor
>
optional_value
=
torch
::
nullopt
;
if
(
value
.
numel
()
>
0
)
optional_value
=
value
;
auto
colptr
=
saved
[
4
];
auto
csr2csc
=
saved
[
5
];
auto
mat
=
saved
[
6
];
auto
row
=
saved
[
0
],
rowptr
=
saved
[
1
],
col
=
saved
[
2
],
value
=
saved
[
3
],
colptr
=
saved
[
4
],
csr2csc
=
saved
[
5
],
mat
=
saved
[
6
];
auto
grad_value
=
Variable
();
if
(
optional_value
.
has_value
()
&&
if
(
value
.
numel
()
>
0
&&
torch
::
autograd
::
any_variable_requires_grad
({
value
}))
{
grad_value
=
spmm_value_bw
(
row
,
rowptr
,
col
,
mat
,
grad_out
,
"sum"
);
}
auto
grad_mat
=
Variable
();
if
(
torch
::
autograd
::
any_variable_requires_grad
({
mat
}))
{
if
(
optional_value
.
has_value
())
optional_value
=
optional_value
.
value
().
index_select
(
0
,
csr2csc
);
grad_mat
=
torch
::
zeros_like
(
mat
);
torch
::
optional
<
torch
::
Tensor
>
opt_value
=
torch
::
nullopt
;
if
(
value
.
numel
()
>
0
)
opt_value
=
value
.
index_select
(
0
,
csr2csc
);
grad_mat
=
std
::
get
<
0
>
(
spmm_fw
(
colptr
,
row
.
index_select
(
0
,
csr2csc
),
opt
ional
_value
,
grad_out
,
"sum"
));
opt_value
,
grad_out
,
"sum"
));
}
return
{
Variable
(),
Variable
(),
Variable
(),
grad_value
,
...
...
@@ -99,20 +88,89 @@ public:
}
};
torch
::
Tensor
spmm_sum
(
torch
::
optional
<
torch
::
Tensor
>
optional_row
,
class
SPMMMean
:
public
torch
::
autograd
::
Function
<
SPMMMean
>
{
public:
static
variable_list
forward
(
AutogradContext
*
ctx
,
torch
::
optional
<
Variable
>
opt_row
,
Variable
rowptr
,
Variable
col
,
Variable
value
,
torch
::
optional
<
Variable
>
opt_rowcount
,
torch
::
optional
<
Variable
>
opt_colptr
,
torch
::
optional
<
Variable
>
opt_csr2csc
,
Variable
mat
)
{
auto
row
=
opt_row
.
has_value
()
?
opt_row
.
value
()
:
torch
::
Tensor
();
auto
rowcount
=
opt_rowcount
.
has_value
()
?
opt_rowcount
.
value
()
:
torch
::
Tensor
();
auto
colptr
=
opt_colptr
.
has_value
()
?
opt_colptr
.
value
()
:
torch
::
Tensor
();
auto
csr2csc
=
opt_csr2csc
.
has_value
()
?
opt_csr2csc
.
value
()
:
torch
::
Tensor
();
torch
::
optional
<
torch
::
Tensor
>
opt_value
=
torch
::
nullopt
;
if
(
value
.
numel
()
>
0
)
opt_value
=
value
;
auto
out
=
std
::
get
<
0
>
(
spmm_fw
(
rowptr
,
col
,
opt_value
,
mat
,
"mean"
));
ctx
->
save_for_backward
(
{
row
,
rowptr
,
col
,
value
,
rowcount
,
colptr
,
csr2csc
,
mat
});
return
{
out
};
}
static
variable_list
backward
(
AutogradContext
*
ctx
,
variable_list
grad_outs
)
{
auto
grad_out
=
grad_outs
[
0
];
auto
saved
=
ctx
->
get_saved_variables
();
auto
row
=
saved
[
0
],
rowptr
=
saved
[
1
],
col
=
saved
[
2
],
value
=
saved
[
3
],
rowcount
=
saved
[
4
],
colptr
=
saved
[
5
],
csr2csc
=
saved
[
6
],
mat
=
saved
[
7
];
auto
grad_value
=
Variable
();
if
(
value
.
numel
()
>
0
&&
torch
::
autograd
::
any_variable_requires_grad
({
value
}))
{
grad_value
=
spmm_value_bw
(
row
,
rowptr
,
col
,
mat
,
grad_out
,
"mean"
);
}
auto
grad_mat
=
Variable
();
if
(
torch
::
autograd
::
any_variable_requires_grad
({
mat
}))
{
row
=
row
.
index_select
(
0
,
csr2csc
);
rowcount
=
rowcount
.
toType
(
mat
.
scalar_type
()).
index_select
(
0
,
row
);
rowcount
.
clamp_
(
1
);
if
(
value
.
numel
()
>
0
)
rowcount
=
value
.
index_select
(
0
,
csr2csc
).
div
(
rowcount
);
else
rowcount
.
pow_
(
-
1
);
grad_mat
=
std
::
get
<
0
>
(
spmm_fw
(
colptr
,
row
,
rowcount
,
grad_out
,
"sum"
));
}
return
{
Variable
(),
Variable
(),
Variable
(),
grad_value
,
Variable
(),
Variable
(),
Variable
(),
grad_mat
};
}
};
torch
::
Tensor
spmm_sum
(
torch
::
optional
<
torch
::
Tensor
>
opt_row
,
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
opt
ional
_value
,
torch
::
optional
<
torch
::
Tensor
>
opt
ional
_colptr
,
torch
::
optional
<
torch
::
Tensor
>
opt
ional
_csr2csc
,
torch
::
optional
<
torch
::
Tensor
>
opt_value
,
torch
::
optional
<
torch
::
Tensor
>
opt_colptr
,
torch
::
optional
<
torch
::
Tensor
>
opt_csr2csc
,
torch
::
Tensor
mat
)
{
// Since we cannot return an *optional* gradient, we need to convert
// `optional_value` to an empty sized tensor first :(
auto
value
=
torch
::
Tensor
();
if
(
optional_value
.
has_value
())
value
=
optional_value
.
value
();
return
SPMMSum
::
apply
(
optional_row
,
rowptr
,
col
,
value
,
optional_colptr
,
optional_csr2csc
,
mat
)[
0
];
// `opt_value` to an empty sized tensor first :(
auto
value
=
opt_value
.
has_value
()
?
opt_value
.
value
()
:
torch
::
Tensor
();
return
SPMMSum
::
apply
(
opt_row
,
rowptr
,
col
,
value
,
opt_colptr
,
opt_csr2csc
,
mat
)[
0
];
}
torch
::
Tensor
spmm_mean
(
torch
::
optional
<
torch
::
Tensor
>
opt_row
,
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
opt_value
,
torch
::
optional
<
torch
::
Tensor
>
opt_rowcount
,
torch
::
optional
<
torch
::
Tensor
>
opt_colptr
,
torch
::
optional
<
torch
::
Tensor
>
opt_csr2csc
,
torch
::
Tensor
mat
)
{
auto
value
=
opt_value
.
has_value
()
?
opt_value
.
value
()
:
torch
::
Tensor
();
return
SPMMMean
::
apply
(
opt_row
,
rowptr
,
col
,
value
,
opt_rowcount
,
opt_colptr
,
opt_csr2csc
,
mat
)[
0
];
}
static
auto
registry
=
torch
::
RegisterOperators
().
op
(
"torch_sparse::spmm_sum"
,
&
spmm_sum
);
static
auto
registry
=
torch
::
RegisterOperators
()
.
op
(
"torch_sparse::spmm_sum"
,
&
spmm_sum
)
.
op
(
"torch_sparse::spmm_mean"
,
&
spmm_mean
);
test/test_matmul.py
View file @
28cb8de4
...
...
@@ -10,7 +10,7 @@ import torch_scatter
from
.utils
import
devices
,
grad_dtypes
reductions
=
[
'sum'
,
'mean'
,
'min'
,
'max'
]
reductions
=
[
'sum'
]
reductions
=
[
'sum'
,
'mean'
]
@
pytest
.
mark
.
parametrize
(
'dtype,device,reduce'
,
...
...
torch_sparse/matmul.py
View file @
28cb8de4
...
...
@@ -19,7 +19,18 @@ except OSError:
raise
ImportError
return
mat
def
spmm_mean_placeholder
(
row
:
Optional
[
torch
.
Tensor
],
rowptr
:
torch
.
Tensor
,
col
:
torch
.
Tensor
,
value
:
Optional
[
torch
.
Tensor
],
rowcount
:
Optional
[
torch
.
Tensor
],
colptr
:
Optional
[
torch
.
Tensor
],
csr2csc
:
Optional
[
torch
.
Tensor
],
mat
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
ImportError
return
mat
torch
.
ops
.
torch_sparse
.
spmm_sum
=
spmm_sum_placeholder
torch
.
ops
.
torch_sparse
.
spmm_mean
=
spmm_mean_placeholder
@
torch
.
jit
.
script
...
...
@@ -47,11 +58,35 @@ def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
return
spmm_sum
(
src
,
other
)
@
torch
.
jit
.
script
def
spmm_mean
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
torch
.
Tensor
:
rowptr
,
col
,
value
=
src
.
csr
()
row
=
src
.
storage
.
_row
rowcount
=
src
.
storage
.
_rowcount
csr2csc
=
src
.
storage
.
_csr2csc
colptr
=
src
.
storage
.
_colptr
if
value
is
not
None
and
value
.
requires_grad
:
row
=
src
.
storage
.
row
()
if
other
.
requires_grad
:
row
=
src
.
storage
.
row
()
rowcount
=
src
.
storage
.
rowcount
()
csr2csc
=
src
.
storage
.
csr2csc
()
colptr
=
src
.
storage
.
colptr
()
return
torch
.
ops
.
torch_sparse
.
spmm_mean
(
row
,
rowptr
,
col
,
value
,
rowcount
,
colptr
,
csr2csc
,
other
)
@
torch
.
jit
.
script
def
spmm
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
reduce
:
str
=
"sum"
)
->
torch
.
Tensor
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
return
spmm_sum
(
src
,
other
)
elif
reduce
==
'mean'
:
return
spmm_mean
(
src
,
other
)
else
:
raise
ValueError
...
...
torch_sparse/storage.py
View file @
28cb8de4
...
...
@@ -274,7 +274,7 @@ class SparseStorage(object):
return
rowcount
rowptr
=
self
.
rowptr
()
rowcount
=
rowptr
[
1
:]
-
rowptr
[
1
:]
rowcount
=
rowptr
[
1
:]
-
rowptr
[:
-
1
]
self
.
_rowcount
=
rowcount
return
rowcount
...
...
@@ -306,7 +306,7 @@ class SparseStorage(object):
colptr
=
self
.
_colptr
if
colptr
is
not
None
:
colcount
=
colptr
[
1
:]
-
colptr
[
1
:]
colcount
=
colptr
[
1
:]
-
colptr
[:
-
1
]
else
:
colcount
=
scatter_add
(
torch
.
ones_like
(
self
.
_col
),
self
.
_col
,
dim_size
=
self
.
_sparse_sizes
[
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment