Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
87c88d95
Commit
87c88d95
authored
Aug 26, 2021
by
rusty1s
Browse files
enable autocast
parent
bdd1ced8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
2 deletions
+16
-2
torch_sparse/matmul.py
torch_sparse/matmul.py
+16
-2
No files found.
torch_sparse/matmul.py
View file @
87c88d95
...
...
@@ -11,6 +11,9 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
csr2csc
=
src
.
storage
.
_csr2csc
colptr
=
src
.
storage
.
_colptr
if
value
is
not
None
:
value
=
value
.
to
(
other
.
dtype
)
if
value
is
not
None
and
value
.
requires_grad
:
row
=
src
.
storage
.
row
()
...
...
@@ -35,6 +38,9 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
csr2csc
=
src
.
storage
.
_csr2csc
colptr
=
src
.
storage
.
_colptr
if
value
is
not
None
:
value
=
value
.
to
(
other
.
dtype
)
if
value
is
not
None
and
value
.
requires_grad
:
row
=
src
.
storage
.
row
()
...
...
@@ -51,12 +57,20 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
def
spmm_min
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
,
value
=
src
.
csr
()
if
value
is
not
None
:
value
=
value
.
to
(
other
.
dtype
)
return
torch
.
ops
.
torch_sparse
.
spmm_min
(
rowptr
,
col
,
value
,
other
)
def
spmm_max
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
,
value
=
src
.
csr
()
if
value
is
not
None
:
value
=
value
.
to
(
other
.
dtype
)
return
torch
.
ops
.
torch_sparse
.
spmm_max
(
rowptr
,
col
,
value
,
other
)
...
...
@@ -81,8 +95,8 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
value
=
valueA
if
valueA
is
not
None
and
valueA
.
dtype
==
torch
.
half
:
valueA
=
valueA
.
to
(
torch
.
float
)
if
valueB
is
not
None
and
valueB
.
dtype
==
torch
.
half
:
valueB
=
valueB
.
to
(
torch
.
float
)
if
valueB
is
not
None
:
valueB
=
valueB
.
to
(
valueA
.
dtype
)
M
,
K
=
src
.
sparse_size
(
0
),
other
.
sparse_size
(
1
)
rowptrC
,
colC
,
valueC
=
torch
.
ops
.
torch_sparse
.
spspmm_sum
(
rowptrA
,
colA
,
valueA
,
rowptrB
,
colB
,
valueB
,
K
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment