Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
e44a639f
Commit
e44a639f
authored
Feb 03, 2020
by
rusty1s
Browse files
spmm done
parent
bb1ba6b0
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
208 additions
and
168 deletions
+208
-168
csrc/spmm.cpp
csrc/spmm.cpp
+135
-1
test/test_matmul.py
test/test_matmul.py
+1
-4
test/utils.py
test/utils.py
+2
-0
torch_sparse/__init__.py
torch_sparse/__init__.py
+2
-1
torch_sparse/matmul.py
torch_sparse/matmul.py
+68
-162
No files found.
csrc/spmm.cpp
View file @
e44a639f
...
@@ -168,6 +168,122 @@ public:
...
@@ -168,6 +168,122 @@ public:
}
}
};
};
class
SPMMMin
:
public
torch
::
autograd
::
Function
<
SPMMMin
>
{
public:
static
variable_list
forward
(
AutogradContext
*
ctx
,
Variable
rowptr
,
Variable
col
,
Variable
value
,
Variable
mat
,
bool
has_value
)
{
torch
::
optional
<
torch
::
Tensor
>
opt_value
=
torch
::
nullopt
;
if
(
has_value
)
opt_value
=
value
;
auto
result
=
spmm_fw
(
rowptr
,
col
,
opt_value
,
mat
,
"min"
);
auto
out
=
std
::
get
<
0
>
(
result
);
auto
arg_out
=
std
::
get
<
1
>
(
result
).
value
();
ctx
->
saved_data
[
"has_value"
]
=
has_value
;
ctx
->
save_for_backward
({
col
,
value
,
mat
,
arg_out
});
ctx
->
mark_non_differentiable
({
arg_out
});
return
{
out
,
arg_out
};
}
static
variable_list
backward
(
AutogradContext
*
ctx
,
variable_list
grad_outs
)
{
auto
has_value
=
ctx
->
saved_data
[
"has_value"
].
toBool
();
auto
grad_out
=
grad_outs
[
0
];
auto
saved
=
ctx
->
get_saved_variables
();
auto
col
=
saved
[
0
],
value
=
saved
[
1
],
mat
=
saved
[
2
],
arg_out
=
saved
[
3
];
auto
invalid_arg_mask
=
arg_out
==
col
.
size
(
0
);
arg_out
=
arg_out
.
masked_fill
(
invalid_arg_mask
,
0
);
auto
grad_value
=
Variable
();
if
(
has_value
>
0
&&
torch
::
autograd
::
any_variable_requires_grad
({
value
}))
{
auto
ind
=
col
.
index_select
(
0
,
arg_out
.
flatten
()).
view_as
(
arg_out
);
auto
out
=
mat
.
gather
(
-
2
,
ind
);
out
.
mul_
(
grad_out
);
out
.
masked_fill_
(
invalid_arg_mask
,
0
);
grad_value
=
torch
::
zeros_like
(
value
);
grad_value
.
scatter_add_
(
0
,
arg_out
.
flatten
(),
out
.
flatten
());
}
auto
grad_mat
=
Variable
();
if
(
torch
::
autograd
::
any_variable_requires_grad
({
mat
}))
{
if
(
has_value
>
0
)
{
value
=
value
.
index_select
(
0
,
arg_out
.
flatten
()).
view_as
(
arg_out
);
value
.
mul_
(
grad_out
);
}
else
value
=
grad_out
;
value
.
masked_fill_
(
invalid_arg_mask
,
0
);
auto
ind
=
col
.
index_select
(
0
,
arg_out
.
flatten
()).
view_as
(
arg_out
);
grad_mat
=
torch
::
zeros_like
(
mat
);
grad_mat
.
scatter_add_
(
-
2
,
ind
,
value
);
}
return
{
Variable
(),
Variable
(),
grad_value
,
grad_mat
,
Variable
()};
}
};
class
SPMMMax
:
public
torch
::
autograd
::
Function
<
SPMMMax
>
{
public:
static
variable_list
forward
(
AutogradContext
*
ctx
,
Variable
rowptr
,
Variable
col
,
Variable
value
,
Variable
mat
,
bool
has_value
)
{
torch
::
optional
<
torch
::
Tensor
>
opt_value
=
torch
::
nullopt
;
if
(
has_value
)
opt_value
=
value
;
auto
result
=
spmm_fw
(
rowptr
,
col
,
opt_value
,
mat
,
"max"
);
auto
out
=
std
::
get
<
0
>
(
result
);
auto
arg_out
=
std
::
get
<
1
>
(
result
).
value
();
ctx
->
saved_data
[
"has_value"
]
=
has_value
;
ctx
->
save_for_backward
({
col
,
value
,
mat
,
arg_out
});
ctx
->
mark_non_differentiable
({
arg_out
});
return
{
out
,
arg_out
};
}
static
variable_list
backward
(
AutogradContext
*
ctx
,
variable_list
grad_outs
)
{
auto
has_value
=
ctx
->
saved_data
[
"has_value"
].
toBool
();
auto
grad_out
=
grad_outs
[
0
];
auto
saved
=
ctx
->
get_saved_variables
();
auto
col
=
saved
[
0
],
value
=
saved
[
1
],
mat
=
saved
[
2
],
arg_out
=
saved
[
3
];
auto
invalid_arg_mask
=
arg_out
==
col
.
size
(
0
);
arg_out
=
arg_out
.
masked_fill
(
invalid_arg_mask
,
0
);
auto
grad_value
=
Variable
();
if
(
has_value
>
0
&&
torch
::
autograd
::
any_variable_requires_grad
({
value
}))
{
auto
ind
=
col
.
index_select
(
0
,
arg_out
.
flatten
()).
view_as
(
arg_out
);
auto
out
=
mat
.
gather
(
-
2
,
ind
);
out
.
mul_
(
grad_out
);
out
.
masked_fill_
(
invalid_arg_mask
,
0
);
grad_value
=
torch
::
zeros_like
(
value
);
grad_value
.
scatter_add_
(
0
,
arg_out
.
flatten
(),
out
.
flatten
());
}
auto
grad_mat
=
Variable
();
if
(
torch
::
autograd
::
any_variable_requires_grad
({
mat
}))
{
if
(
has_value
>
0
)
{
value
=
value
.
index_select
(
0
,
arg_out
.
flatten
()).
view_as
(
arg_out
);
value
.
mul_
(
grad_out
);
}
else
value
=
grad_out
;
value
.
masked_fill_
(
invalid_arg_mask
,
0
);
auto
ind
=
col
.
index_select
(
0
,
arg_out
.
flatten
()).
view_as
(
arg_out
);
grad_mat
=
torch
::
zeros_like
(
mat
);
grad_mat
.
scatter_add_
(
-
2
,
ind
,
value
);
}
return
{
Variable
(),
Variable
(),
grad_value
,
grad_mat
,
Variable
()};
}
};
torch
::
Tensor
spmm_sum
(
torch
::
optional
<
torch
::
Tensor
>
opt_row
,
torch
::
Tensor
spmm_sum
(
torch
::
optional
<
torch
::
Tensor
>
opt_row
,
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
opt_value
,
torch
::
optional
<
torch
::
Tensor
>
opt_value
,
...
@@ -191,6 +307,24 @@ torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
...
@@ -191,6 +307,24 @@ torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
opt_csr2csc
,
mat
,
opt_value
.
has_value
())[
0
];
opt_csr2csc
,
mat
,
opt_value
.
has_value
())[
0
];
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
spmm_min
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
opt_value
,
torch
::
Tensor
mat
)
{
auto
value
=
opt_value
.
has_value
()
?
opt_value
.
value
()
:
col
;
auto
result
=
SPMMMin
::
apply
(
rowptr
,
col
,
value
,
mat
,
opt_value
.
has_value
());
return
std
::
make_tuple
(
result
[
0
],
result
[
1
]);
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
spmm_max
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
opt_value
,
torch
::
Tensor
mat
)
{
auto
value
=
opt_value
.
has_value
()
?
opt_value
.
value
()
:
col
;
auto
result
=
SPMMMax
::
apply
(
rowptr
,
col
,
value
,
mat
,
opt_value
.
has_value
());
return
std
::
make_tuple
(
result
[
0
],
result
[
1
]);
}
static
auto
registry
=
torch
::
RegisterOperators
()
static
auto
registry
=
torch
::
RegisterOperators
()
.
op
(
"torch_sparse::spmm_sum"
,
&
spmm_sum
)
.
op
(
"torch_sparse::spmm_sum"
,
&
spmm_sum
)
.
op
(
"torch_sparse::spmm_mean"
,
&
spmm_mean
);
.
op
(
"torch_sparse::spmm_mean"
,
&
spmm_mean
)
.
op
(
"torch_sparse::spmm_min"
,
&
spmm_min
)
.
op
(
"torch_sparse::spmm_max"
,
&
spmm_max
);
test/test_matmul.py
View file @
e44a639f
...
@@ -7,10 +7,7 @@ from torch_sparse.matmul import matmul
...
@@ -7,10 +7,7 @@ from torch_sparse.matmul import matmul
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
import
torch_scatter
import
torch_scatter
from
.utils
import
devices
,
grad_dtypes
from
.utils
import
reductions
,
devices
,
grad_dtypes
reductions
=
[
'sum'
,
'mean'
,
'min'
,
'max'
]
reductions
=
[
'sum'
,
'mean'
]
@
pytest
.
mark
.
parametrize
(
'dtype,device,reduce'
,
@
pytest
.
mark
.
parametrize
(
'dtype,device,reduce'
,
...
...
test/utils.py
View file @
e44a639f
import
torch
import
torch
reductions
=
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
dtypes
=
[
torch
.
float
,
torch
.
double
,
torch
.
int
,
torch
.
long
]
dtypes
=
[
torch
.
float
,
torch
.
double
,
torch
.
int
,
torch
.
long
]
grad_dtypes
=
[
torch
.
float
,
torch
.
double
]
grad_dtypes
=
[
torch
.
float
,
torch
.
double
]
...
...
torch_sparse/__init__.py
View file @
e44a639f
...
@@ -46,4 +46,5 @@ from .diag import set_diag, remove_diag
...
@@ -46,4 +46,5 @@ from .diag import set_diag, remove_diag
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
from
.reduce
import
sum
,
mean
,
min
,
max
from
.reduce
import
sum
,
mean
,
min
,
max
from
.matmul
import
spmm_sum
,
spmm_add
,
spmm
,
matmul
from
.matmul
import
(
spmm_sum
,
spmm_add
,
spmm_mean
,
spmm_min
,
spmm_max
,
spmm
,
spspmm_sum
,
spspmm_add
,
spspmm
,
matmul
)
torch_sparse/matmul.py
View file @
e44a639f
import
warnings
import
warnings
import
os.path
as
osp
import
os.path
as
osp
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
,
Tuple
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
...
@@ -29,8 +29,26 @@ except OSError:
...
@@ -29,8 +29,26 @@ except OSError:
raise
ImportError
raise
ImportError
return
mat
return
mat
def
spmm_min_max_placeholder
(
rowptr
:
torch
.
Tensor
,
col
:
torch
.
Tensor
,
value
:
Optional
[
torch
.
Tensor
],
mat
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
ImportError
return
mat
,
mat
def
spspmm_sum_placeholder
(
rowptrA
:
torch
.
Tensor
,
colA
:
torch
.
Tensor
,
valueA
:
Optional
[
torch
.
Tensor
],
rowptrB
:
torch
.
Tensor
,
colB
:
torch
.
Tensor
,
valueB
:
Optional
[
torch
.
Tensor
],
K
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
raise
ImportError
return
rowptrA
,
colA
,
valueA
torch
.
ops
.
torch_sparse
.
spmm_sum
=
spmm_sum_placeholder
torch
.
ops
.
torch_sparse
.
spmm_sum
=
spmm_sum_placeholder
torch
.
ops
.
torch_sparse
.
spmm_mean
=
spmm_mean_placeholder
torch
.
ops
.
torch_sparse
.
spmm_mean
=
spmm_mean_placeholder
torch
.
ops
.
torch_sparse
.
spmm_min
=
spmm_min_max_placeholder
torch
.
ops
.
torch_sparse
.
spmm_max
=
spmm_min_max_placeholder
torch
.
ops
.
torch_sparse
.
spspmm_sum
=
spspmm_sum_placeholder
@
torch
.
jit
.
script
@
torch
.
jit
.
script
...
@@ -80,6 +98,20 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
...
@@ -80,6 +98,20 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
colptr
,
csr2csc
,
other
)
colptr
,
csr2csc
,
other
)
@
torch
.
jit
.
script
def
spmm_min
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
,
value
=
src
.
csr
()
return
torch
.
ops
.
torch_sparse
.
spmm_min
(
rowptr
,
col
,
value
,
other
)
@
torch
.
jit
.
script
def
spmm_max
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
,
value
=
src
.
csr
()
return
torch
.
ops
.
torch_sparse
.
spmm_max
(
rowptr
,
col
,
value
,
other
)
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
spmm
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
def
spmm
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
reduce
:
str
=
"sum"
)
->
torch
.
Tensor
:
reduce
:
str
=
"sum"
)
->
torch
.
Tensor
:
...
@@ -87,6 +119,37 @@ def spmm(src: SparseTensor, other: torch.Tensor,
...
@@ -87,6 +119,37 @@ def spmm(src: SparseTensor, other: torch.Tensor,
return
spmm_sum
(
src
,
other
)
return
spmm_sum
(
src
,
other
)
elif
reduce
==
'mean'
:
elif
reduce
==
'mean'
:
return
spmm_mean
(
src
,
other
)
return
spmm_mean
(
src
,
other
)
elif
reduce
==
'min'
:
return
spmm_min
(
src
,
other
)[
0
]
elif
reduce
==
'max'
:
return
spmm_max
(
src
,
other
)[
0
]
else
:
raise
ValueError
@
torch
.
jit
.
script
def
spspmm_sum
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
rowptrA
,
colA
,
valueA
=
src
.
csr
()
rowptrB
,
colB
,
valueB
=
other
.
csr
()
M
,
K
=
src
.
sparse_size
(
0
),
other
.
sparse_size
(
1
)
rowptrC
,
colC
,
valueC
=
torch
.
ops
.
torch_sparse
.
spspmm_sum
(
rowptrA
,
colA
,
valueA
,
rowptrB
,
colB
,
valueB
,
K
)
return
SparseTensor
(
row
=
None
,
rowptr
=
rowptrC
,
col
=
colC
,
value
=
valueC
,
sparse_sizes
=
torch
.
Size
([
M
,
K
]),
is_sorted
=
True
)
@
torch
.
jit
.
script
def
spspmm_add
(
src
:
SparseTensor
,
other
:
SparseTensor
)
->
SparseTensor
:
return
spspmm_sum
(
src
,
other
)
@
torch
.
jit
.
script
def
spspmm
(
src
:
SparseTensor
,
other
:
SparseTensor
,
reduce
:
str
=
"sum"
)
->
SparseTensor
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
return
spspmm_sum
(
src
,
other
)
elif
reduce
==
'mean'
or
reduce
==
'min'
or
reduce
==
'max'
:
raise
NotImplementedError
else
:
else
:
raise
ValueError
raise
ValueError
...
@@ -95,172 +158,15 @@ def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
...
@@ -95,172 +158,15 @@ def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
reduce
:
str
=
"sum"
):
reduce
:
str
=
"sum"
):
if
torch
.
is_tensor
(
other
):
if
torch
.
is_tensor
(
other
):
return
spmm
(
src
,
other
,
reduce
)
return
spmm
(
src
,
other
,
reduce
)
elif
isinstance
(
other
,
SparseTensor
):
return
spspmm
(
src
,
other
,
reduce
)
else
:
else
:
raise
ValueError
raise
ValueError
SparseTensor
.
spmm
=
lambda
self
,
other
,
reduce
=
None
:
spmm
(
self
,
other
,
reduce
)
SparseTensor
.
spmm
=
lambda
self
,
other
,
reduce
=
None
:
spmm
(
self
,
other
,
reduce
)
SparseTensor
.
spspmm
=
lambda
self
,
other
,
reduce
=
None
:
spspmm
(
self
,
other
,
reduce
)
SparseTensor
.
matmul
=
lambda
self
,
other
,
reduce
=
None
:
matmul
(
SparseTensor
.
matmul
=
lambda
self
,
other
,
reduce
=
None
:
matmul
(
self
,
other
,
reduce
)
self
,
other
,
reduce
)
SparseTensor
.
__matmul__
=
lambda
self
,
other
:
matmul
(
self
,
other
,
'sum'
)
SparseTensor
.
__matmul__
=
lambda
self
,
other
:
matmul
(
self
,
other
,
'sum'
)
# class SPMM(torch.autograd.Function):
# @staticmethod
# def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
# reduce):
# if mat.is_cuda:
# out, arg_out = torch.ops.torch_sparse_cuda.spmm(
# rowptr, col, value, mat, reduce)
# else:
# out, arg_out = torch.ops.torch_sparse_cpu.spmm(
# rowptr, col, value, mat, reduce)
# ctx.reduce = reduce
# ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
# csr2csc, arg_out)
# if reduce == 'min' or reduce == 'max':
# ctx.mark_non_differentiable(arg_out)
# return out, arg_out
# else:
# return out
# @staticmethod
# def backward(ctx, grad_out, *args):
# (row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
# arg_out) = ctx.saved_tensors
# invalid_arg_mask = arg_out_ind = None
# if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
# or ctx.needs_input_grad[4]):
# invalid_arg_mask = arg_out == col.size(0)
# arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)
# grad_value = None
# if ctx.needs_input_grad[3]:
# if ctx.reduce in ['sum', 'add', 'mean']:
# grad_value = ext(grad_out.is_cuda).spmm_val_bw(
# row, rowptr, col, mat, grad_out, ctx.reduce)
# elif ctx.reduce in ['min', 'max']:
# col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
# out = mat.gather(-2, col_tmp).mul_(grad_out)
# out.masked_fill_(invalid_arg_mask, 0)
# grad_value = scatter_add(out.flatten(), arg_out.flatten(),
# dim=0, dim_size=value.numel() + 1)
# grad_value = grad_value[:-1]
# grad_mat = None
# if ctx.needs_input_grad[4]:
# if ctx.reduce in ['sum', 'add']:
# value = value[csr2csc] if value is not None else value
# grad_mat, _ = ext(grad_out.is_cuda).spmm(
# colptr, row[csr2csc], value, grad_out, 'sum')
# elif ctx.reduce == 'mean':
# count = rowcount[row].to(mat.dtype).clamp_(min=1)
# value = count.pow_(-1) if value is None else value / count
# row = row[csr2csc]
# value = value[csr2csc] if value is not None else value
# grad_mat, _ = ext(grad_out.is_cuda).spmm(
# colptr, row, value, grad_out, 'sum')
# elif ctx.reduce in ['min', 'max']:
# if value is not None:
# value = value[arg_out_ind.flatten()].view_as(arg_out)
# value = value.mul_(grad_out)
# else:
# value = grad_out
# value.masked_fill_(invalid_arg_mask, 0)
# col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
# grad_mat = scatter_add(value, col_tmp, dim=-2,
# dim_size=mat.size(-2))
# return None, None, None, grad_value, grad_mat, None, None, None, None
# class SPSPMM(torch.autograd.Function):
# @staticmethod
# def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
# if rowptrA.is_cuda:
# rowptrC, colC, valueC = ext(True).spspmm(rowptrA, colA, valueA,
# rowptrB, colB, valueB, M,
# N, K)
# else:
# dtype = None
# if valueA is not None:
# dtype = valueA.dtype
# if valueB is not None:
# dtype = valueB.dtype
# if valueA is None:
# valueA = torch.ones(colA.numel(), dtype=dtype)
# A = scipy.sparse.csr_matrix((valueA, colA, rowptrA), (M, N))
# if valueB is None:
# valueB = torch.ones(colB.numel(), dtype=dtype)
# B = scipy.sparse.csr_matrix((valueB, colB, rowptrB), (N, K))
# C = A @ B
# rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
# colC = torch.from_numpy(C.indices).to(torch.int64)
# valueC = torch.from_numpy(C.data)
# valueC = valueC.to(dtype) if dtype is not None else None
# ctx.mark_non_differentiable(rowptrC, colC)
# # We cannot return `NoneType` in torch.autograd :(
# if valueC is None:
# return rowptrC, colC
# else:
# return rowptrC, colC, valueC
# @staticmethod
# def backward(ctx, grad_indexC, grad_rowptrC, *args):
# grad_valueA = None
# if ctx.needs_input_grad[2]:
# raise NotImplementedError
# grad_valueB = None
# if ctx.needs_input_grad[5]:
# raise NotImplementedError
# return (None, None, grad_valueA, None, None, grad_valueB, None, None,
# None)
# def matmul(src, other, reduce='sum'):
# assert src.dim() == 2 and src.size(-1) == other.size(-2)
# # Sparse-Dense Matrix Multiplication.
# if torch.is_tensor(other):
# assert reduce in ['sum', 'add', 'mean', 'min', 'max']
# rowptr, col, value = src.csr()
# row = None
# if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
# or other.requires_grad):
# row = src.storage.row
# rowcount = None
# if other.requires_grad and reduce in ['mean']:
# rowcount = src.storage.rowcount
# csr2csc = colptr = None
# if other.requires_grad and reduce in ['sum', 'add', 'mean']:
# csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
# return SPMM.apply(row, rowptr, col, value, other, rowcount, colptr,
# csr2csc, reduce)
# # Sparse-Sparse Matrix Multiplication.
# elif isinstance(other, src.__class__):
# assert reduce in ['sum', 'add']
# assert src.dim() == 2 and other.dim() == 2
# data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0), src.size(1),
# other.size(1))
# (rowptr, col), value = data[:2], data[2] if len(data) == 3 else None
# sparse_size = torch.Size([src.size(0), other.size(1)])
# return src.__class__(rowptr=rowptr, col=col, value=value,
# sparse_size=sparse_size, is_sorted=True)
# raise ValueError
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment