Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
df5f7063
Commit
df5f7063
authored
Jan 22, 2020
by
rusty1s
Browse files
spmm backward implementation
parent
b3187f23
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
281 additions
and
39 deletions
+281
-39
cpu/spmm.cpp
cpu/spmm.cpp
+62
-14
test/test_matmul.py
test/test_matmul.py
+92
-0
torch_sparse/matmul.py
torch_sparse/matmul.py
+102
-7
torch_sparse/reduce.py
torch_sparse/reduce.py
+12
-15
torch_sparse/storage.py
torch_sparse/storage.py
+4
-3
torch_sparse/tensor.py
torch_sparse/tensor.py
+9
-0
No files found.
cpu/spmm.cpp
View file @
df5f7063
...
@@ -101,7 +101,6 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
...
@@ -101,7 +101,6 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
if
(
value_opt
.
has_value
())
if
(
value_opt
.
has_value
())
AT_ASSERTM
(
value_opt
.
value
().
dim
()
==
1
);
AT_ASSERTM
(
value_opt
.
value
().
dim
()
==
1
);
AT_ASSERTM
(
mat
.
dim
()
>=
2
,
"Input mismatch"
);
AT_ASSERTM
(
mat
.
dim
()
>=
2
,
"Input mismatch"
);
AT_ASSERTM
(
rowptr
.
numel
()
-
1
==
mat
.
size
(
-
2
),
"Input mismatch"
);
auto
sizes
=
mat
.
sizes
().
vec
();
auto
sizes
=
mat
.
sizes
().
vec
();
sizes
[
mat
.
dim
()
-
2
]
=
rowptr
.
numel
()
-
1
;
sizes
[
mat
.
dim
()
-
2
]
=
rowptr
.
numel
()
-
1
;
...
@@ -110,26 +109,26 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
...
@@ -110,26 +109,26 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
at
::
optional
<
at
::
Tensor
>
arg_out
=
at
::
nullopt
;
at
::
optional
<
at
::
Tensor
>
arg_out
=
at
::
nullopt
;
int64_t
*
arg_out_data
=
nullptr
;
int64_t
*
arg_out_data
=
nullptr
;
if
(
reduce2REDUCE
.
at
(
reduce
)
==
MIN
||
reduce2REDUCE
.
at
(
reduce
)
==
MAX
)
{
if
(
reduce2REDUCE
.
at
(
reduce
)
==
MIN
||
reduce2REDUCE
.
at
(
reduce
)
==
MAX
)
{
arg_out
=
at
::
full_like
(
out
,
mat
.
size
(
-
2
)
,
rowptr
.
options
());
arg_out
=
at
::
full_like
(
out
,
-
1
,
rowptr
.
options
());
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
}
}
auto
rowptr_data
=
rowptr
.
DATA_PTR
<
int64_t
>
();
auto
rowptr_data
=
rowptr
.
DATA_PTR
<
int64_t
>
();
auto
col_data
=
col
.
DATA_PTR
<
int64_t
>
();
auto
col_data
=
col
.
DATA_PTR
<
int64_t
>
();
int
N
=
rowptr
.
numel
()
-
1
;
auto
N
=
rowptr
.
numel
()
-
1
;
int
M
=
mat
.
size
(
-
2
);
auto
M
=
mat
.
size
(
-
2
);
int
K
=
mat
.
size
(
-
1
);
auto
K
=
mat
.
size
(
-
1
);
int
B
=
mat
.
numel
()
/
(
M
*
K
);
auto
B
=
mat
.
numel
()
/
(
M
*
K
);
AT_DISPATCH_ALL_TYPES
(
mat
.
scalar_type
(),
"spmm"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES
(
mat
.
scalar_type
(),
"spmm"
,
[
&
]
{
scalar_t
*
value_data
=
nullptr
;
scalar_t
*
value_data
=
nullptr
;
auto
mat_data
=
ou
t
.
DATA_PTR
<
scalar_t
>
();
auto
mat_data
=
ma
t
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
ma
t
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
ou
t
.
DATA_PTR
<
scalar_t
>
();
scalar_t
val
;
scalar_t
val
;
std
::
vector
<
scalar_t
>
vals
(
K
);
std
::
vector
<
scalar_t
>
vals
(
K
);
int64_t
row_start
,
row_end
,
c
ol_idx
;
int64_t
row_start
,
row_end
,
c
;
std
::
vector
<
int64_t
>
args
(
K
);
std
::
vector
<
int64_t
>
args
(
K
);
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
...
@@ -147,18 +146,17 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
...
@@ -147,18 +146,17 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
int
offset
=
b
*
M
*
K
;
int
offset
=
b
*
M
*
K
;
for
(
int
e
=
row_start
;
e
<
row_end
;
e
++
)
{
for
(
int
e
=
row_start
;
e
<
row_end
;
e
++
)
{
c
ol_idx
=
col_data
[
e
];
c
=
col_data
[
e
];
if
(
HAS_VAL
)
if
(
HAS_VAL
)
val
=
value_data
[
e
];
val
=
value_data
[
e
];
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
if
(
HAS_VAL
)
if
(
HAS_VAL
)
Reducer
<
scalar_t
,
REDUCE
>::
update
(
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
vals
[
k
],
val
*
mat_data
[
offset
+
c
ol_idx
*
K
+
k
],
&
vals
[
k
],
val
*
mat_data
[
offset
+
c
*
K
+
k
],
&
args
[
k
],
&
args
[
k
],
e
);
e
);
else
else
Reducer
<
scalar_t
,
REDUCE
>::
update
(
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
vals
[
k
],
mat_data
[
offset
+
col_idx
*
K
+
k
],
&
args
[
k
],
&
vals
[
k
],
mat_data
[
offset
+
c
*
K
+
k
],
&
args
[
k
],
e
);
e
);
}
}
}
}
offset
=
b
*
N
*
K
+
n
*
K
;
offset
=
b
*
N
*
K
+
n
*
K
;
...
@@ -175,6 +173,56 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
...
@@ -175,6 +173,56 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
return
std
::
make_tuple
(
out
,
arg_out
);
return
std
::
make_tuple
(
out
,
arg_out
);
}
}
at
::
Tensor
spmm_val_bw
(
at
::
Tensor
rowptr
,
at
::
Tensor
col
,
at
::
Tensor
mat
,
at
::
Tensor
grad
,
std
::
string
reduce
)
{
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
CHECK_CPU
(
mat
);
CHECK_CPU
(
grad
);
mat
=
mat
.
contiguous
();
auto
M
=
rowptr
.
numel
()
-
1
;
auto
N
=
mat
.
size
(
-
2
);
auto
K
=
mat
.
size
(
-
1
);
auto
B
=
mat
.
numel
()
/
(
N
*
K
);
auto
out
=
at
::
zeros
(
col
.
sizes
(),
grad
.
options
());
auto
rowptr_data
=
rowptr
.
DATA_PTR
<
int64_t
>
();
auto
col_data
=
col
.
DATA_PTR
<
int64_t
>
();
AT_DISPATCH_ALL_TYPES
(
mat
.
scalar_type
(),
"spmm_val_bw"
,
[
&
]
{
auto
mat_data
=
mat
.
DATA_PTR
<
scalar_t
>
();
auto
grad_data
=
grad
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
scalar_t
val
;
int64_t
row_start
,
row_end
,
c
;
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
for
(
int
b
=
0
;
b
<
B
;
b
++
)
{
for
(
int
m
=
0
;
m
<
M
;
m
++
)
{
row_start
=
rowptr_data
[
m
],
row_end
=
rowptr_data
[
m
+
1
];
for
(
int
e
=
row_start
;
e
<
row_end
;
e
++
)
{
c
=
col_data
[
e
],
val
=
(
scalar_t
)
0
;
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
val
+=
mat_data
[
b
*
N
*
K
+
c
*
K
+
k
]
*
grad_data
[
b
*
M
*
K
+
m
*
K
+
k
];
}
if
(
REDUCE
==
MEAN
)
val
=
val
/
(
scalar_t
)(
row_end
-
row_start
);
out_data
[
e
]
+=
val
;
}
}
}
});
});
return
out
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"spmm"
,
&
spmm
,
"Sparse-Dense Matrix Multiplication (CPU)"
);
m
.
def
(
"spmm"
,
&
spmm
,
"Sparse-Dense Matrix Multiplication (CPU)"
);
m
.
def
(
"spmm_val_bw"
,
&
spmm_val_bw
,
"Sparse-Dense Matrix Multiplication Value Backward (CPU)"
);
}
}
test/test_matmul.py
0 → 100644
View file @
df5f7063
from
itertools
import
product
import
pytest
import
torch
from
torch.autograd
import
gradcheck
from
torch_sparse.matmul
import
matmul
from
torch_sparse.tensor
import
SparseTensor
import
torch_scatter
from
.utils
import
tensor
,
devices
,
dtypes
devices
=
[
'cpu'
]
dtypes
=
[
torch
.
float
]
reductions
=
[
'sum'
,
'mean'
,
'min'
,
'max'
]
# grad_reductions = ['sum', 'mean']
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_spmm_forward
(
dtype
,
device
):
src_dense
=
torch
.
randn
((
5
,
4
),
dtype
=
dtype
,
device
=
device
)
src
=
SparseTensor
.
from_dense
(
src_dense
)
src
.
requires_grad_
()
src_dense
=
src_dense
.
clone
().
requires_grad_
()
other
=
torch
.
randn
((
4
,
8
),
dtype
=
dtype
,
device
=
device
)
other
.
requires_grad_
()
out1
=
matmul
(
src
,
other
)
grad_out
=
torch
.
randn_like
(
out1
)
out1
.
backward
(
grad_out
)
other
.
grad
=
None
out2
=
torch
.
matmul
(
src_dense
,
other
)
out2
.
backward
(
grad_out
)
# assert torch.allclose(out1, out2)
# assert torch.allclose(src.storage.value.grad.view(5, 4), src_dense.grad)
@
pytest
.
mark
.
parametrize
(
'dtype,device,reduce'
,
product
(
dtypes
,
devices
,
reductions
))
def
test_spmm
(
dtype
,
device
,
reduce
):
src
=
torch
.
ones
((
5
,
4
),
dtype
=
dtype
,
device
=
device
)
src
[
2
]
=
0
src
=
SparseTensor
.
from_dense
(
src
).
requires_grad_
()
src
.
set_value_
(
None
)
other
=
torch
.
randn
((
2
,
4
,
2
),
dtype
=
dtype
,
device
=
device
,
requires_grad
=
True
)
(
row
,
col
),
value
=
src
.
coo
()
out1
=
other
.
index_select
(
-
2
,
col
)
# * value.unsqueeze(-1)
func
=
'add'
if
reduce
==
'sum'
else
reduce
out1
=
getattr
(
torch_scatter
,
f
'scatter_
{
func
}
'
)(
out1
,
row
,
dim
=-
2
)
out1
=
out1
[
0
]
if
isinstance
(
out1
,
tuple
)
else
out1
grad_out
=
torch
.
randn_like
(
out1
)
out1
.
backward
(
grad_out
)
# grad_value1 = value.grad
# value.grad = None
grad_other1
=
other
.
grad
other
.
grad
=
None
print
(
reduce
)
out2
=
matmul
(
src
,
other
,
reduce
)
out2
=
out2
[
0
]
if
isinstance
(
out2
,
tuple
)
else
out2
out2
.
backward
(
grad_out
)
# grad_value2 = value.grad
# value.grad = None
grad_other2
=
other
.
grad
other
.
grad
=
None
# assert torch.allclose(out1, out2)
# assert torch.allclose(grad_value1, grad_value2)
assert
torch
.
allclose
(
grad_other1
,
grad_other2
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_spmm_backward
(
dtype
,
device
):
src_dense
=
torch
.
randn
((
5
,
4
),
dtype
=
torch
.
double
,
device
=
device
)
src
=
SparseTensor
.
from_dense
(
src_dense
)
src
.
requires_grad_
()
other
=
torch
.
randn
((
4
,
8
),
dtype
=
torch
.
double
,
device
=
device
)
other
.
requires_grad_
()
# assert gradcheck(matmul, (src, other, "sum"))
torch_sparse/matmul.py
View file @
df5f7063
import
torch
import
torch
from
torch_sparse
import
spmm_cpu
from
torch_scatter
import
scatter_add
try
:
from
torch_sparse
import
spmm_cuda
except
ImportError
:
spmm_cuda
=
None
def
spmm
(
is_cuda
):
return
spmm_cuda
if
is_cuda
else
spmm_cpu
class
SPMM
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
index
,
rowcount
,
rowptr
,
colptr
,
csr2csc
,
value
,
mat
,
reduce
):
out
,
arg_out
=
spmm
(
mat
.
is_cuda
).
spmm
(
rowptr
,
index
[
1
],
value
,
mat
,
reduce
)
ctx
.
reduce
=
reduce
ctx
.
save_for_backward
(
index
,
rowcount
,
rowptr
,
colptr
,
csr2csc
,
value
,
mat
,
arg_out
)
if
reduce
==
'min'
or
reduce
==
'max'
:
return
out
,
arg_out
else
:
return
out
@
staticmethod
def
backward
(
ctx
,
grad_out
,
*
args
):
data
=
ctx
.
saved_tensors
index
,
rowcount
,
rowptr
,
colptr
,
csr2csc
,
value
,
mat
,
arg_out
=
data
grad_value
=
None
if
ctx
.
needs_input_grad
[
5
]:
if
ctx
.
reduce
in
[
'sum'
,
'add'
]:
grad_value
=
spmm
(
grad_out
.
is_cuda
).
spmm_val_bw
(
rowptr
,
index
[
1
],
mat
,
grad_out
,
ctx
.
reduce
)
if
ctx
.
reduce
==
'mean'
:
grad_value
=
spmm
(
grad_out
.
is_cuda
).
spmm_val_bw
(
rowptr
,
index
[
1
],
mat
,
grad_out
,
ctx
.
reduce
)
elif
ctx
.
reduce
in
[
'min'
,
'max'
]:
col
=
index
[
1
][
arg_out
.
flatten
()].
view_as
(
arg_out
)
out
=
mat
.
gather
(
-
2
,
col
).
mul_
(
grad_out
)
out
.
masked_fill_
(
arg_out
==
-
1
,
0
)
col
=
col
.
add_
(
rowptr
[:
-
1
].
view
(
-
1
,
1
))
grad_value
=
scatter_add
(
out
.
flatten
(),
col
.
flatten
(),
dim
=
0
,
dim_size
=
value
.
numel
())
grad_mat
=
None
if
ctx
.
needs_input_grad
[
6
]:
if
ctx
.
reduce
in
[
'sum'
,
'add'
]:
row
=
index
[
0
][
csr2csc
]
value
=
value
[
csr2csc
]
if
value
is
not
None
else
value
grad_mat
,
_
=
spmm
(
grad_out
.
is_cuda
).
spmm
(
colptr
,
row
,
value
,
grad_out
,
'sum'
)
elif
ctx
.
reduce
==
'mean'
:
count
=
rowcount
[
index
[
0
]].
to
(
mat
.
dtype
).
clamp_
(
min
=
1
)
value
=
count
.
pow_
(
-
1
)
if
value
is
None
else
value
/
count
row
=
index
[
0
][
csr2csc
]
value
=
value
[
csr2csc
]
if
value
is
not
None
else
value
grad_mat
,
_
=
spmm
(
grad_out
.
is_cuda
).
spmm
(
colptr
,
row
,
value
,
grad_out
,
'sum'
)
elif
ctx
.
reduce
in
[
'min'
,
'max'
]:
if
value
is
not
None
:
value
=
value
[
arg_out
.
flatten
()].
view_as
(
arg_out
)
value
=
value
.
mul_
(
grad_out
)
else
:
value
=
grad_out
value
.
masked_fill_
(
arg_out
==
-
1
,
0
)
col
=
index
[
1
][
arg_out
.
flatten
()].
view_as
(
arg_out
)
grad_mat
=
scatter_add
(
value
,
col
,
dim
=-
2
,
dim_size
=
mat
.
size
(
-
2
))
return
None
,
None
,
None
,
None
,
None
,
grad_value
,
grad_mat
,
None
def
matmul
(
src
,
other
,
reduce
=
'sum'
):
assert
src
.
dim
()
==
2
and
src
.
size
(
-
1
)
==
other
.
size
(
-
2
)
def
matmul
(
src
,
other
,
reduce
=
'add'
):
if
torch
.
is_tensor
(
other
):
if
torch
.
is_tensor
(
other
):
pass
assert
reduce
in
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
if
isinstance
(
other
,
src
.
__class__
):
(
index
,
value
),
rowptr
=
src
.
coo
(),
src
.
storage
.
rowptr
if
reduce
!=
'add'
:
raise
NotImplementedError
(
csr2csc
=
colptr
=
None
(
f
'Reduce argument "
{
reduce
}
" not implemented for sparse-'
if
other
.
requires_grad
and
reduce
in
[
'sum'
,
'add'
,
'mean'
]:
f
'sparse matrix multiplication'
))
csr2csc
,
colptr
=
src
.
storage
.
csr2csc
,
src
.
storage
.
colptr
rowcount
=
None
if
other
.
requires_grad
and
reduce
in
[
'mean'
]:
rowcount
=
src
.
storage
.
rowcount
return
SPMM
.
apply
(
index
,
rowcount
,
rowptr
,
colptr
,
csr2csc
,
value
,
other
,
reduce
)
elif
isinstance
(
other
,
src
.
__class__
):
assert
reduce
in
[
'sum'
,
'add'
]
raise
ValueError
torch_sparse/reduce.py
View file @
df5f7063
...
@@ -3,15 +3,14 @@ import torch_scatter
...
@@ -3,15 +3,14 @@ import torch_scatter
from
torch_scatter
import
segment_csr
from
torch_scatter
import
segment_csr
def
reduction
(
src
,
dim
=
None
,
reduce
=
'
add
'
,
deterministic
=
False
):
def
reduction
(
src
,
dim
=
None
,
reduce
=
'
sum
'
,
deterministic
=
False
):
assert
reduce
in
[
'
add
'
,
'mean'
,
'min'
,
'max'
]
assert
reduce
in
[
'
sum
'
,
'mean'
,
'min'
,
'max'
]
if
dim
is
None
and
src
.
has_value
():
if
dim
is
None
and
src
.
has_value
():
func
=
getattr
(
torch
,
'sum'
if
reduce
==
'add'
else
reduce
)
return
getattr
(
torch
,
reduce
)(
src
.
storage
.
value
)
return
func
(
src
.
storage
.
value
)
if
dim
is
None
and
not
src
.
has_value
():
if
dim
is
None
and
not
src
.
has_value
():
value
=
src
.
nnz
()
if
reduce
==
'
add
'
else
1
value
=
src
.
nnz
()
if
reduce
==
'
sum
'
else
1
return
torch
.
tensor
(
value
,
device
=
src
.
device
)
return
torch
.
tensor
(
value
,
device
=
src
.
device
)
dims
=
[
dim
]
if
isinstance
(
dim
,
int
)
else
dim
dims
=
[
dim
]
if
isinstance
(
dim
,
int
)
else
dim
...
@@ -24,25 +23,22 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
...
@@ -24,25 +23,22 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
dense_dims
=
tuple
(
set
([
d
-
1
for
d
in
dims
if
d
>
1
]))
dense_dims
=
tuple
(
set
([
d
-
1
for
d
in
dims
if
d
>
1
]))
if
len
(
sparse_dims
)
==
2
and
src
.
has_value
():
if
len
(
sparse_dims
)
==
2
and
src
.
has_value
():
func
=
getattr
(
torch
,
'sum'
if
reduce
==
'add'
else
reduce
)
return
getattr
(
torch
,
reduce
)(
value
,
dim
=
(
0
,
)
+
dense_dims
)
return
func
(
value
,
dim
=
(
0
,
)
+
dense_dims
)
if
len
(
sparse_dims
)
==
2
and
not
src
.
has_value
():
if
len
(
sparse_dims
)
==
2
and
not
src
.
has_value
():
value
=
src
.
nnz
()
if
reduce
==
'
add
'
else
1
value
=
src
.
nnz
()
if
reduce
==
'
sum
'
else
1
return
torch
.
tensor
(
value
,
device
=
src
.
device
)
return
torch
.
tensor
(
value
,
device
=
src
.
device
)
if
len
(
dense_dims
)
>
0
and
len
(
sparse_dims
)
==
0
:
# src.has_value()
if
len
(
dense_dims
)
>
0
and
len
(
sparse_dims
)
==
0
:
# src.has_value()
func
=
getattr
(
torch
,
'sum'
if
reduce
==
'add'
else
reduce
)
dense_dims
=
dense_dims
[
0
]
if
len
(
dense_dims
)
==
1
else
dense_dims
dense_dims
=
dense_dims
[
0
]
if
len
(
dense_dims
)
==
1
else
dense_dims
value
=
func
(
value
,
dim
=
dense_dims
)
value
=
getattr
(
torch
,
reduce
)
(
value
,
dim
=
dense_dims
)
if
isinstance
(
value
,
tuple
):
if
isinstance
(
value
,
tuple
):
return
(
src
.
set_value
(
value
[
0
],
layout
=
'csr'
),
)
+
value
[
1
:]
return
(
src
.
set_value
(
value
[
0
],
layout
=
'csr'
),
)
+
value
[
1
:]
return
src
.
set_value
(
value
,
layout
=
'csr'
)
return
src
.
set_value
(
value
,
layout
=
'csr'
)
if
len
(
dense_dims
)
>
0
and
len
(
sparse_dims
)
>
0
:
if
len
(
dense_dims
)
>
0
and
len
(
sparse_dims
)
>
0
:
func
=
getattr
(
torch
,
'sum'
if
reduce
==
'add'
else
reduce
)
dense_dims
=
dense_dims
[
0
]
if
len
(
dense_dims
)
==
1
else
dense_dims
dense_dims
=
dense_dims
[
0
]
if
len
(
dense_dims
)
==
1
else
dense_dims
value
=
func
(
value
,
dim
=
dense_dims
)
value
=
getattr
(
torch
,
reduce
)
(
value
,
dim
=
dense_dims
)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
if
sparse_dims
[
0
]
==
1
and
src
.
has_value
():
if
sparse_dims
[
0
]
==
1
and
src
.
has_value
():
...
@@ -51,7 +47,7 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
...
@@ -51,7 +47,7 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
return
out
return
out
if
sparse_dims
[
0
]
==
1
and
not
src
.
has_value
():
if
sparse_dims
[
0
]
==
1
and
not
src
.
has_value
():
if
reduce
==
'
add
'
:
if
reduce
==
'
sum
'
:
return
src
.
storage
.
rowcount
.
to
(
torch
.
get_default_dtype
())
return
src
.
storage
.
rowcount
.
to
(
torch
.
get_default_dtype
())
elif
reduce
==
'min'
or
'max'
:
elif
reduce
==
'min'
or
'max'
:
# Return an additional `None` arg(min|max) tensor for consistency.
# Return an additional `None` arg(min|max) tensor for consistency.
...
@@ -68,13 +64,14 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
...
@@ -68,13 +64,14 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
return
out
return
out
if
sparse_dims
[
0
]
==
0
and
src
.
has_value
():
if
sparse_dims
[
0
]
==
0
and
src
.
has_value
():
reduce
=
'add'
if
reduce
==
'sum'
else
reduce
func
=
getattr
(
torch_scatter
,
f
'scatter_
{
reduce
}
'
)
func
=
getattr
(
torch_scatter
,
f
'scatter_
{
reduce
}
'
)
out
=
func
(
value
,
col
,
dim
=
0
,
dim_size
=
src
.
sparse_size
(
1
))
out
=
func
(
value
,
col
,
dim
=
0
,
dim_size
=
src
.
sparse_size
(
1
))
out
=
out
[
0
]
if
len
(
dense_dims
)
>
0
and
isinstance
(
out
,
tuple
)
else
out
out
=
out
[
0
]
if
len
(
dense_dims
)
>
0
and
isinstance
(
out
,
tuple
)
else
out
return
out
return
out
if
sparse_dims
[
0
]
==
0
and
not
src
.
has_value
():
if
sparse_dims
[
0
]
==
0
and
not
src
.
has_value
():
if
reduce
==
'
add
'
:
if
reduce
==
'
sum
'
:
return
src
.
storage
.
colcount
.
to
(
torch
.
get_default_dtype
())
return
src
.
storage
.
colcount
.
to
(
torch
.
get_default_dtype
())
elif
reduce
==
'min'
or
'max'
:
elif
reduce
==
'min'
or
'max'
:
# Return an additional `None` arg(min|max) tensor for consistency.
# Return an additional `None` arg(min|max) tensor for consistency.
...
@@ -84,7 +81,7 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
...
@@ -84,7 +81,7 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
def
sum
(
src
,
dim
=
None
,
deterministic
=
False
):
def
sum
(
src
,
dim
=
None
,
deterministic
=
False
):
return
reduction
(
src
,
dim
,
reduce
=
'
add
'
,
deterministic
=
deterministic
)
return
reduction
(
src
,
dim
,
reduce
=
'
sum
'
,
deterministic
=
deterministic
)
def
mean
(
src
,
dim
=
None
,
deterministic
=
False
):
def
mean
(
src
,
dim
=
None
,
deterministic
=
False
):
...
...
torch_sparse/storage.py
View file @
df5f7063
...
@@ -164,8 +164,9 @@ class SparseStorage(object):
...
@@ -164,8 +164,9 @@ class SparseStorage(object):
value
=
torch
.
full
((
self
.
nnz
(),
),
device
=
self
.
index
.
device
)
value
=
torch
.
full
((
self
.
nnz
(),
),
device
=
self
.
index
.
device
)
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
value
=
value
[
self
.
csc2csr
]
value
=
value
[
self
.
csc2csr
]
assert
value
.
device
==
self
.
index
.
device
if
torch
.
is_tensor
(
value
):
assert
value
.
size
(
0
)
==
self
.
index
.
size
(
1
)
assert
value
.
device
==
self
.
index
.
device
assert
value
.
size
(
0
)
==
self
.
index
.
size
(
1
)
self
.
_value
=
value
self
.
_value
=
value
return
self
return
self
...
@@ -268,7 +269,7 @@ class SparseStorage(object):
...
@@ -268,7 +269,7 @@ class SparseStorage(object):
@
cached_property
@
cached_property
def
colptr
(
self
):
def
colptr
(
self
):
if
self
.
_csr2csc
:
if
self
.
has
_csr2csc
()
:
func
=
rowptr_cuda
if
self
.
index
.
is_cuda
else
rowptr_cpu
func
=
rowptr_cuda
if
self
.
index
.
is_cuda
else
rowptr_cpu
return
func
.
rowptr
(
self
.
col
[
self
.
csr2csc
],
self
.
sparse_size
(
1
))
return
func
.
rowptr
(
self
.
col
[
self
.
csr2csc
],
self
.
sparse_size
(
1
))
else
:
else
:
...
...
torch_sparse/tensor.py
View file @
df5f7063
...
@@ -214,6 +214,15 @@ class SparseTensor(object):
...
@@ -214,6 +214,15 @@ class SparseTensor(object):
def
detach
(
self
):
def
detach
(
self
):
return
self
.
from_storage
(
self
.
storage
.
apply
(
lambda
x
:
x
.
detach
()))
return
self
.
from_storage
(
self
.
storage
.
apply
(
lambda
x
:
x
.
detach
()))
@
property
def
requires_grad
(
self
):
return
self
.
storage
.
value
.
requires_grad
if
self
.
has_value
()
else
False
def
requires_grad_
(
self
,
requires_grad
=
True
):
if
self
.
has_value
():
self
.
storage
.
value
.
requires_grad_
(
requires_grad
)
return
self
def
pin_memory
(
self
):
def
pin_memory
(
self
):
return
self
.
from_storage
(
self
.
storage
.
apply
(
lambda
x
:
x
.
pin_memory
()))
return
self
.
from_storage
(
self
.
storage
.
apply
(
lambda
x
:
x
.
pin_memory
()))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment