Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
238efb11
Commit
238efb11
authored
Mar 21, 2019
by
rusty1s
Browse files
faster spspmm backward + cleanup
parent
5586d7ae
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
262 additions
and
44 deletions
+262
-44
cpu/spspmm.cpp
cpu/spspmm.cpp
+66
-0
cuda/spspmm.cpp
cuda/spspmm.cpp
+19
-2
cuda/spspmm_kernel.cu
cuda/spspmm_kernel.cu
+70
-1
setup.py
setup.py
+5
-4
test/test_transpose.py
test/test_transpose.py
+18
-1
torch_sparse/__init__.py
torch_sparse/__init__.py
+6
-2
torch_sparse/convert.py
torch_sparse/convert.py
+18
-0
torch_sparse/spspmm.py
torch_sparse/spspmm.py
+34
-30
torch_sparse/transpose.py
torch_sparse/transpose.py
+26
-4
No files found.
cpu/spspmm.cpp
0 → 100644
View file @
238efb11
#include <torch/extension.h>
// Returns a histogram of node occurrences in `row`, i.e. the number of
// stored entries per row when `row` holds row indices of a sparse matrix.
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  // Scatter-add a one per entry into a zero-initialized count vector.
  auto count = at::zeros(num_nodes, row.options());
  auto ones = at::ones(row.size(0), row.options());
  return count.scatter_add_(0, row, ones);
}
// Converts a coalesced COO (row, col) pair into CSR form: returns the
// row-pointer tensor (length `num_nodes + 1`) together with the unchanged
// column indices.
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  // Assumes the input is already coalesced (sorted by row, no duplicates).
  auto rowptr = degree(row, num_nodes).cumsum(0);
  // Prepend a zero so rowptr[i]..rowptr[i+1] delimits row i.
  rowptr = at::cat({at::zeros(1, rowptr.options()), rowptr}, 0);
  return std::make_tuple(rowptr, col);
}
// Backward pass of sparse-sparse matrix multiplication on CPU.
// For every nonzero position (i, j) listed in `index` (shape [2, E]),
// computes the dot product of row i of A with row j of B, restricted to
// matching column indices: value[e] = sum_c A[i, c] * B[j, c].
// `rowA_max` / `rowB_max` are the number of rows of A and B, used to build
// their CSR row pointers. Assumes indexA/indexB are coalesced (column
// indices sorted within each row) — the early `break` below relies on it.
at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
                     at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
                     size_t rowB_max) {
  int64_t *index_data = index.data<int64_t>();
  // Output gradient values, one per entry of `index`.
  auto value = at::zeros(index.size(1), valueA.options());

  // CSR view of A: rowA is the row-pointer array, colA the column indices.
  at::Tensor rowA, colA;
  std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
  int64_t *rowA_data = rowA.data<int64_t>();
  int64_t *colA_data = colA.data<int64_t>();

  // CSR view of B.
  at::Tensor rowB, colB;
  std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
  int64_t *rowB_data = rowB.data<int64_t>();
  int64_t *colB_data = colB.data<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
    scalar_t *value_data = value.data<scalar_t>();
    scalar_t *valueA_data = valueA.data<scalar_t>();
    scalar_t *valueB_data = valueB.data<scalar_t>();

    for (int64_t e = 0; e < value.size(0); e++) {
      // index is stored row-major as [2, E]: row e, then col at offset E.
      int64_t i = index_data[e], j = index_data[value.size(0) + e];

      // Walk the nonzeros of row i of A.
      for (ptrdiff_t dA = rowA_data[i]; dA < rowA_data[i + 1]; dA++) {
        int64_t cA = colA_data[dA];

        // Scan row j of B for a matching column index.
        for (ptrdiff_t dB = rowB_data[j]; dB < rowB_data[j + 1]; dB++) {
          int64_t cB = colB_data[dB];

          if (cA == cB) {
            value_data[e] += valueA_data[dA] * valueB_data[dB];
          }

          // Columns are sorted, so once cB reaches cA no later entry of
          // row j can match — stop scanning early.
          if (cB >= cA) {
            break;
          }
        }
      }
    }
  });

  return value;
}
// Exposes the CPU backward kernel to Python as `spspmm_bw`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("spspmm_bw", &spspmm_bw,
        "Sparse-Sparse Matrix Multiplication Backward (CPU)");
}
cuda/spspmm.cpp
View file @
238efb11
...
@@ -4,11 +4,14 @@
...
@@ -4,11 +4,14 @@
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
spspmm_cuda
(
at
::
Tensor
indexA
,
at
::
Tensor
valueA
,
at
::
Tensor
indexB
,
spspmm_cuda
(
at
::
Tensor
indexA
,
at
::
Tensor
valueA
,
at
::
Tensor
indexB
,
at
::
Tensor
valueB
,
int
m
,
int
k
,
int
n
);
at
::
Tensor
valueB
,
size_t
m
,
size_t
k
,
size_t
n
);
at
::
Tensor
spspmm_bw_cuda
(
at
::
Tensor
index
,
at
::
Tensor
indexA
,
at
::
Tensor
valueA
,
at
::
Tensor
indexB
,
at
::
Tensor
valueB
,
size_t
rowA_max
,
size_t
rowB_max
);
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
spspmm
(
at
::
Tensor
indexA
,
at
::
Tensor
valueA
,
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
spspmm
(
at
::
Tensor
indexA
,
at
::
Tensor
valueA
,
at
::
Tensor
indexB
,
at
::
Tensor
valueB
,
at
::
Tensor
indexB
,
at
::
Tensor
valueB
,
int
m
,
int
k
,
in
t
n
)
{
size_t
m
,
size_t
k
,
size_
t
n
)
{
CHECK_CUDA
(
indexA
);
CHECK_CUDA
(
indexA
);
CHECK_CUDA
(
valueA
);
CHECK_CUDA
(
valueA
);
CHECK_CUDA
(
indexB
);
CHECK_CUDA
(
indexB
);
...
@@ -16,6 +19,20 @@ std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor valueA,
...
@@ -16,6 +19,20 @@ std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor valueA,
return
spspmm_cuda
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
);
return
spspmm_cuda
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
);
}
}
// Input-checking wrapper around the CUDA backward kernel: verifies that
// every tensor argument lives on the GPU, then dispatches to
// `spspmm_bw_cuda` (declared above, defined in cuda/spspmm_kernel.cu).
at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
                     at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
                     size_t rowB_max) {
  CHECK_CUDA(index);
  CHECK_CUDA(indexA);
  CHECK_CUDA(valueA);
  CHECK_CUDA(indexB);
  CHECK_CUDA(valueB);
  return spspmm_bw_cuda(index, indexA, valueA, indexB, valueB, rowA_max,
                        rowB_max);
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"spspmm"
,
&
spspmm
,
"Sparse-Sparse Matrix Multiplication (CUDA)"
);
m
.
def
(
"spspmm"
,
&
spspmm
,
"Sparse-Sparse Matrix Multiplication (CUDA)"
);
m
.
def
(
"spspmm_bw"
,
&
spspmm_bw
,
"Sparse-Sparse Matrix Multiplication Backward (CUDA)"
);
}
}
cuda/spspmm_kernel.cu
View file @
238efb11
...
@@ -2,6 +2,9 @@
...
@@ -2,6 +2,9 @@
#include <cusparse.h>
#include <cusparse.h>
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define CSRGEMM(TYPE, ...) \
#define CSRGEMM(TYPE, ...) \
[&] { \
[&] { \
const at::Type &the_type = TYPE; \
const at::Type &the_type = TYPE; \
...
@@ -29,7 +32,7 @@ static void init_cusparse() {
...
@@ -29,7 +32,7 @@ static void init_cusparse() {
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
spspmm_cuda
(
at
::
Tensor
indexA
,
at
::
Tensor
valueA
,
at
::
Tensor
indexB
,
spspmm_cuda
(
at
::
Tensor
indexA
,
at
::
Tensor
valueA
,
at
::
Tensor
indexB
,
at
::
Tensor
valueB
,
int
m
,
int
k
,
in
t
n
)
{
at
::
Tensor
valueB
,
size_t
m
,
size_t
k
,
size_
t
n
)
{
cudaSetDevice
(
indexA
.
get_device
());
cudaSetDevice
(
indexA
.
get_device
());
init_cusparse
();
init_cusparse
();
...
@@ -90,3 +93,69 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
...
@@ -90,3 +93,69 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
return
std
::
make_tuple
(
indexC
,
valueC
);
return
std
::
make_tuple
(
indexC
,
valueC
);
}
}
// Counts the occurrences of each node id in `row` via scatter-add of ones.
// Mirrors the identical helper in cpu/spspmm.cpp, but here the tensors are
// expected to live on the GPU.
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  auto zero = at::zeros(num_nodes, row.options());
  auto one = at::ones(row.size(0), row.options());
  return zero.scatter_add_(0, row, one);
}
// Converts coalesced COO (row, col) indices into CSR form: returns the
// cumulative row-pointer tensor (with a leading zero) and the untouched
// column indices.
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  // Assert already coalesced input.
  row = degree(row, num_nodes).cumsum(0);
  row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
  return std::make_tuple(row, col);
}
// CUDA kernel for the spspmm backward pass. Each thread handles a subset
// of the E output entries via a grid-stride loop; entry e accumulates
// value[e] = sum over matching columns c of valueA[i, c] * valueB[j, c],
// where (i, j) = (index[e], index[numel + e]). rowA/colA and rowB/colB are
// the CSR row pointers and column indices of A and B. Relies on column
// indices being sorted within each row (coalesced input) for the early
// break below.
template <typename scalar_t>
__global__ void
spspmm_bw_kernel(const int64_t *__restrict__ index,
                 scalar_t *__restrict__ value,
                 const int64_t *__restrict__ rowA,
                 const int64_t *__restrict__ colA,
                 const scalar_t *__restrict__ valueA,
                 const int64_t *__restrict__ rowB,
                 const int64_t *__restrict__ colB,
                 const scalar_t *__restrict__ valueB, const size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  // Grid-stride loop over all output entries.
  for (ptrdiff_t e = idx; e < numel; e += stride) {
    // `index` is laid out as [2, numel]: rows first, then columns.
    int64_t i = index[e], j = index[numel + e];
    // Intersect row i of A with row j of B on column indices.
    for (ptrdiff_t dA = rowA[i]; dA < rowA[i + 1]; dA++) {
      int64_t cA = colA[dA];
      for (ptrdiff_t dB = rowB[j]; dB < rowB[j + 1]; dB++) {
        int64_t cB = colB[dB];
        if (cA == cB) {
          value[e] += valueA[dA] * valueB[dB];
        }
        // Sorted columns: nothing past cB >= cA can match cA.
        if (cB >= cA) {
          break;
        }
      }
    }
  }
}
// Host-side driver for the CUDA backward pass: builds CSR views of A and B,
// allocates the output gradient vector (one value per entry of `index`),
// and launches `spspmm_bw_kernel` over all entries.
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
                          at::Tensor valueA, at::Tensor indexB,
                          at::Tensor valueB, size_t rowA_max,
                          size_t rowB_max) {
  // Make sure the kernel launches on the device that owns the inputs.
  cudaSetDevice(index.get_device());
  auto value = at::zeros(index.size(1), valueA.options());

  at::Tensor rowA, colA;
  std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);

  at::Tensor rowB, colB;
  std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);

  AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
    spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
        index.data<int64_t>(), value.data<scalar_t>(), rowA.data<int64_t>(),
        colA.data<int64_t>(), valueA.data<scalar_t>(), rowB.data<int64_t>(),
        colB.data<int64_t>(), valueB.data<scalar_t>(), value.numel());
  });
  return value;
}
setup.py
View file @
238efb11
import
platform
import
platform
from
setuptools
import
setup
,
find_packages
from
setuptools
import
setup
,
find_packages
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDAExtension
,
CUDA_HOME
import
torch
from
torch.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
CUDA_HOME
__version__
=
'0.
2.4
'
__version__
=
'0.
3.0
'
url
=
'https://github.com/rusty1s/pytorch_sparse'
url
=
'https://github.com/rusty1s/pytorch_sparse'
install_requires
=
[
'scipy'
]
install_requires
=
[
'scipy'
]
setup_requires
=
[
'pytest-runner'
]
setup_requires
=
[
'pytest-runner'
]
tests_require
=
[
'pytest'
,
'pytest-cov'
]
tests_require
=
[
'pytest'
,
'pytest-cov'
]
ext_modules
=
[]
ext_modules
=
[
CppExtension
(
'torch_sparse.spspmm_cpu'
,
[
'cpu/spspmm.cpp'
])
]
cmdclass
=
{}
cmdclass
=
{}
if
CUDA_HOME
is
not
None
:
if
CUDA_HOME
is
not
None
:
...
@@ -25,7 +26,7 @@ if CUDA_HOME is not None:
...
@@ -25,7 +26,7 @@ if CUDA_HOME is not None:
CUDAExtension
(
'torch_sparse.unique_cuda'
,
CUDAExtension
(
'torch_sparse.unique_cuda'
,
[
'cuda/unique.cpp'
,
'cuda/unique_kernel.cu'
]),
[
'cuda/unique.cpp'
,
'cuda/unique_kernel.cu'
]),
]
]
cmdclass
[
'build_ext'
]
=
BuildExtension
cmdclass
[
'build_ext'
]
=
torch
.
utils
.
cpp_extension
.
BuildExtension
setup
(
setup
(
name
=
'torch_sparse'
,
name
=
'torch_sparse'
,
...
...
test/test_transpose.py
View file @
238efb11
from
itertools
import
product
import
pytest
import
torch
import
torch
from
torch_sparse
import
transpose
from
torch_sparse
import
transpose
,
transpose_matrix
from
.utils
import
dtypes
,
devices
,
tensor
def
test_transpose
():
def
test_transpose
():
...
@@ -11,3 +16,15 @@ def test_transpose():
...
@@ -11,3 +16,15 @@ def test_transpose():
index
,
value
=
transpose
(
index
,
value
,
m
=
3
,
n
=
2
)
index
,
value
=
transpose
(
index
,
value
,
m
=
3
,
n
=
2
)
assert
index
.
tolist
()
==
[[
0
,
0
,
1
,
1
],
[
1
,
2
,
0
,
1
]]
assert
index
.
tolist
()
==
[[
0
,
0
,
1
,
1
],
[
1
,
2
,
0
,
1
]]
assert
value
.
tolist
()
==
[[
7
,
9
],
[
5
,
6
],
[
6
,
8
],
[
3
,
4
]]
assert
value
.
tolist
()
==
[[
7
,
9
],
[
5
,
6
],
[
6
,
8
],
[
3
,
4
]]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_transpose_matrix(dtype, device):
    # 3 x 2 sparse matrix with entries (1,0)=1, (0,1)=2, (1,1)=3, (2,0)=4.
    row = torch.tensor([1, 0, 1, 2], device=device)
    col = torch.tensor([0, 1, 1, 0], device=device)
    index = torch.stack([row, col], dim=0)
    value = tensor([1, 2, 3, 4], dtype, device)

    index, value = transpose_matrix(index, value, m=3, n=2)
    # Transposed (2 x 3) entries, coalesced in row-major order:
    # (0,1)=1, (0,2)=4, (1,0)=2, (1,1)=3.
    assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
    assert value.tolist() == [1, 4, 2, 3]
torch_sparse/__init__.py
View file @
238efb11
from
.convert
import
to_scipy
,
from_scipy
from
.coalesce
import
coalesce
from
.coalesce
import
coalesce
from
.transpose
import
transpose
from
.transpose
import
transpose
,
transpose_matrix
from
.eye
import
eye
from
.eye
import
eye
from
.spmm
import
spmm
from
.spmm
import
spmm
from
.spspmm
import
spspmm
from
.spspmm
import
spspmm
__version__
=
'0.
2.4
'
__version__
=
'0.
3.0
'
__all__
=
[
__all__
=
[
'__version__'
,
'__version__'
,
'to_scipy'
,
'from_scipy'
,
'coalesce'
,
'coalesce'
,
'transpose'
,
'transpose'
,
'transpose_matrix'
,
'eye'
,
'eye'
,
'spmm'
,
'spmm'
,
'spspmm'
,
'spspmm'
,
...
...
torch_sparse/convert.py
0 → 100644
View file @
238efb11
import
numpy
as
np
import
scipy.sparse
import
torch
from
torch
import
from_numpy
def to_scipy(index, value, m, n):
    """Build a ``scipy.sparse.coo_matrix`` of shape ``(m, n)`` from a sparse
    ``(index, value)`` pair.

    Args:
        index (LongTensor): ``[2, nnz]`` tensor holding row/column indices.
        value (Tensor): ``[nnz]`` tensor of matrix entries.
        m (int): Number of rows.
        n (int): Number of columns.
    """
    # scipy operates on host memory, so CUDA tensors are not supported.
    assert not index.is_cuda and not value.is_cuda
    row, col = index.detach()
    data = value.detach()
    return scipy.sparse.coo_matrix((data, (row, col)), (m, n))
def from_scipy(A):
    """Convert a scipy sparse matrix into a sparse ``(index, value)`` pair.

    Args:
        A (scipy.sparse matrix): Input matrix (any scipy sparse format).

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    coo = A.tocoo()
    # Cast indices to int64 so the resulting index tensor is a LongTensor.
    row = from_numpy(coo.row.astype(np.int64))
    col = from_numpy(coo.col.astype(np.int64))
    value = from_numpy(coo.data)
    index = torch.stack([row, col], dim=0)
    return index, value
torch_sparse/spspmm.py
View file @
238efb11
import
torch
import
torch
from
torch
import
from_numpy
from
torch_sparse
import
transpose_matrix
,
to_scipy
,
from_scipy
import
numpy
as
np
import
scipy.sparse
import
torch_sparse.spspmm_cpu
from
torch_sparse
import
transpose
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
import
torch_sparse.spspmm_cuda
import
torch_sparse.spspmm_cuda
...
@@ -38,19 +37,36 @@ class SpSpMM(torch.autograd.Function):
...
@@ -38,19 +37,36 @@ class SpSpMM(torch.autograd.Function):
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_indexC
,
grad_valueC
):
def
backward
(
ctx
,
grad_indexC
,
grad_valueC
):
m
,
k
,
n
=
ctx
.
m
,
ctx
.
k
,
ctx
.
n
m
,
k
=
ctx
.
m
,
ctx
.
k
n
=
ctx
.
n
indexA
,
valueA
,
indexB
,
valueB
,
indexC
=
ctx
.
saved_tensors
indexA
,
valueA
,
indexB
,
valueB
,
indexC
=
ctx
.
saved_tensors
grad_valueA
=
grad_valueB
=
None
grad_valueA
=
grad_valueB
=
None
if
not
grad_valueC
.
is_cuda
:
if
ctx
.
needs_input_grad
[
1
]
or
ctx
.
needs_input_grad
[
1
]:
grad_valueC
=
grad_valueC
.
clone
()
if
ctx
.
needs_input_grad
[
1
]:
grad_valueA
=
torch_sparse
.
spspmm_cpu
.
spspmm_bw
(
indexA
,
indexC
.
detach
(),
grad_valueC
,
indexB
.
detach
(),
valueB
,
m
,
k
)
if
ctx
.
needs_input_grad
[
3
]:
indexA
,
valueA
=
transpose_matrix
(
indexA
,
valueA
,
m
,
k
)
indexC
,
grad_valueC
=
transpose_matrix
(
indexC
,
grad_valueC
,
m
,
n
)
grad_valueB
=
torch_sparse
.
spspmm_cpu
.
spspmm_bw
(
indexB
,
indexA
.
detach
(),
valueA
,
indexC
.
detach
(),
grad_valueC
,
k
,
n
)
else
:
if
ctx
.
needs_input_grad
[
1
]:
if
ctx
.
needs_input_grad
[
1
]:
indexB_T
,
valueB_T
=
transpose
(
indexB
,
valueB
,
k
,
n
)
grad_valueA
=
torch_sparse
.
spspmm_cuda
.
spspmm_bw
(
grad_indexA
,
grad_valueA
=
mm
(
indexC
,
grad_valueC
,
indexB_T
,
indexA
,
indexC
.
detach
(),
grad_valueC
.
clone
(),
valueB_T
,
m
,
n
,
k
)
indexB
.
detach
(),
valueB
,
m
,
k
)
grad_valueA
=
lift
(
grad_indexA
,
grad_valueA
,
indexA
,
k
)
if
ctx
.
needs_input_grad
[
3
]:
if
ctx
.
needs_input_grad
[
3
]:
indexA_T
,
valueA_T
=
transpose
(
indexA
,
valueA
,
m
,
k
)
indexA_T
,
valueA_T
=
transpose
_matrix
(
indexA
,
valueA
,
m
,
k
)
grad_indexB
,
grad_valueB
=
mm
(
indexA_T
,
valueA_T
,
indexC
,
grad_indexB
,
grad_valueB
=
mm
(
indexA_T
,
valueA_T
,
indexC
,
grad_valueC
,
k
,
m
,
n
)
grad_valueC
,
k
,
m
,
n
)
grad_valueB
=
lift
(
grad_indexB
,
grad_valueB
,
indexB
,
n
)
grad_valueB
=
lift
(
grad_indexB
,
grad_valueB
,
indexB
,
n
)
...
@@ -67,23 +83,11 @@ def mm(indexA, valueA, indexB, valueB, m, k, n):
...
@@ -67,23 +83,11 @@ def mm(indexA, valueA, indexB, valueB, m, k, n):
A
=
to_scipy
(
indexA
,
valueA
,
m
,
k
)
A
=
to_scipy
(
indexA
,
valueA
,
m
,
k
)
B
=
to_scipy
(
indexB
,
valueB
,
k
,
n
)
B
=
to_scipy
(
indexB
,
valueB
,
k
,
n
)
indexC
,
valueC
=
from_scipy
(
A
.
tocsr
().
dot
(
B
.
tocsr
()
)
.
tocoo
()
)
C
=
A
.
dot
(
B
).
tocoo
()
.
tocsr
().
tocoo
()
# Force coalesce.
indexC
,
valueC
=
from_scipy
(
C
)
return
indexC
,
valueC
return
indexC
,
valueC
def
to_scipy
(
index
,
value
,
m
,
n
):
(
row
,
col
),
data
=
index
.
detach
(),
value
.
detach
()
return
scipy
.
sparse
.
coo_matrix
((
data
,
(
row
,
col
)),
(
m
,
n
))
def
from_scipy
(
A
):
row
,
col
,
value
=
A
.
row
.
astype
(
np
.
int64
),
A
.
col
.
astype
(
np
.
int64
),
A
.
data
row
,
col
,
value
=
from_numpy
(
row
),
from_numpy
(
col
),
from_numpy
(
value
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
index
,
value
def
lift
(
indexA
,
valueA
,
indexB
,
n
):
# pragma: no cover
def
lift
(
indexA
,
valueA
,
indexB
,
n
):
# pragma: no cover
idxA
=
indexA
[
0
]
*
n
+
indexA
[
1
]
idxA
=
indexA
[
0
]
*
n
+
indexA
[
1
]
idxB
=
indexB
[
0
]
*
n
+
indexB
[
1
]
idxB
=
indexB
[
0
]
*
n
+
indexB
[
1
]
...
...
torch_sparse/transpose.py
View file @
238efb11
import
torch
import
torch
from
torch_sparse
import
coalesce
from
torch_sparse
import
to_scipy
,
from_scipy
,
coalesce
def
transpose
(
index
,
value
,
m
,
n
):
def
transpose
(
index
,
value
,
m
,
n
):
"""Transposes dimensions 0 and 1 of a sparse
matrix
.
"""Transposes dimensions 0 and 1 of a sparse
tensor
.
Args:
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
index (:class:`LongTensor`): The index tensor of sparse matrix.
...
@@ -16,7 +16,29 @@ def transpose(index, value, m, n):
...
@@ -16,7 +16,29 @@ def transpose(index, value, m, n):
row
,
col
=
index
row
,
col
=
index
index
=
torch
.
stack
([
col
,
row
],
dim
=
0
)
index
=
torch
.
stack
([
col
,
row
],
dim
=
0
)
index
,
value
=
coalesce
(
index
,
value
,
n
,
m
)
index
,
value
=
coalesce
(
index
,
value
,
n
,
m
)
return
index
,
value
def transpose_matrix(index, value, m, n):
    """Transposes dimensions 0 and 1 of a sparse matrix, where :args:`value` is
    one-dimensional.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of sparse matrix.
        n (int): The second dimension of sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    assert value.dim() == 1
    if not index.is_cuda:
        # CSC conversion sorts entries by column, i.e. by the row of the
        # transposed matrix, so the result comes out coalesced for free.
        mat = to_scipy(index, value, m, n).tocsc()
        # from_scipy yields (original_row, original_col); for the transpose
        # those become the column and row of the output, respectively.
        (col, row), value = from_scipy(mat)
        index = torch.stack([row, col], dim=0)
        return index, value
    else:
        # No scipy path on GPU; fall back to the generic transpose.
        return transpose(index, value, m, n)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment