OpenDAS / torch-sparse · Commits

Commit bb1ba6b0
Authored Feb 02, 2020 by rusty1s
fixed no value
Parent: 28cb8de4
Showing 1 changed file with 49 additions and 29 deletions: csrc/spmm.cpp (+49, -29)
csrc/spmm.cpp (view file @ bb1ba6b0)
@@ -46,37 +46,48 @@ public:
                            Variable rowptr, Variable col, Variable value,
                            torch::optional<Variable> opt_colptr,
                            torch::optional<Variable> opt_csr2csc,
-                           Variable mat) {
-    auto row = opt_row.has_value() ? opt_row.value() : torch::Tensor();
-    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : torch::Tensor();
-    auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : torch::Tensor();
+                           Variable mat, bool has_value) {
+
+    if (has_value && torch::autograd::any_variable_requires_grad({value})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+    }
+    if (torch::autograd::any_variable_requires_grad({mat})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+      AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
+      AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
+    }
+
+    auto row = opt_row.has_value() ? opt_row.value() : col;
+    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
+    auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;
 
     torch::optional<torch::Tensor> opt_value = torch::nullopt;
-    if (value.numel() > 0)
+    if (has_value)
       opt_value = value;
 
     auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "sum"));
+    ctx->saved_data["has_value"] = has_value;
     ctx->save_for_backward({row, rowptr, col, value, colptr, csr2csc, mat});
     return {out};
   }
 
   static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
+    auto has_value = ctx->saved_data["has_value"].toBool();
     auto grad_out = grad_outs[0];
     auto saved = ctx->get_saved_variables();
    auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
          colptr = saved[4], csr2csc = saved[5], mat = saved[6];
 
     auto grad_value = Variable();
-    if (value.numel() > 0 && torch::autograd::any_variable_requires_grad({value})) {
+    if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
       grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "sum");
     }
 
     auto grad_mat = Variable();
     if (torch::autograd::any_variable_requires_grad({mat})) {
       torch::optional<torch::Tensor> opt_value = torch::nullopt;
-      if (value.numel() > 0)
+      if (has_value)
         opt_value = value.index_select(0, csr2csc);
 
       grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc),
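The heart of the fix is in this hunk: SPMMSum::forward now receives an explicit bool has_value flag instead of inferring the presence of edge values from value.numel() > 0. Non-tensor arguments cannot go through save_for_backward(), so the flag is stashed in ctx->saved_data and read back with .toBool() in backward(). Below is a minimal, self-contained sketch of that pattern; ScaleIfPresent, its arguments, and its math are illustrative stand-ins, not torch-sparse code.

#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// Hypothetical op: multiply x by scale only when a scale was supplied.
struct ScaleIfPresent : public torch::autograd::Function<ScaleIfPresent> {
  static variable_list forward(AutogradContext *ctx, Variable x,
                               Variable scale, bool has_scale) {
    ctx->saved_data["has_scale"] = has_scale; // non-tensor state
    ctx->save_for_backward({x, scale});
    return {has_scale ? x * scale : x};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grads) {
    auto has_scale = ctx->saved_data["has_scale"].toBool();
    auto saved = ctx->get_saved_variables();
    auto x = saved[0], scale = saved[1];
    auto grad_x = has_scale ? grads[0] * scale : grads[0];
    auto grad_scale = Variable();
    if (has_scale && torch::autograd::any_variable_requires_grad({scale}))
      grad_scale = grads[0] * x; // assumes scale has the same shape as x
    // One slot per forward() argument; the bool gets an undefined Variable(),
    // exactly like the extra slot added to the return lists in this commit.
    return {grad_x, grad_scale, Variable()};
  }
};

A call such as ScaleIfPresent::apply(x, scale, /*has_scale=*/true) then behaves like any differentiable op, with the flag carried along outside the autograd tape.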
@@ -84,7 +95,7 @@ public:
     }
 
-    return {Variable(), Variable(), Variable(), grad_value, Variable(),
-            Variable(), grad_mat};
+    return {Variable(), Variable(), Variable(), grad_value, Variable(),
+            Variable(), grad_mat, Variable()};
   }
 };
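Why the extra Variable() here: backward() must return exactly one entry per argument of forward(), in order. With the new trailing bool, SPMMSum::forward takes eight arguments, so the return list grows from seven entries to eight. Mapping the slots to the new signature (annotation only, not added code):

// forward(ctx, opt_row, rowptr, col, value, opt_colptr, opt_csr2csc,
//         mat, has_value)
// backward returns:
//   {Variable(),  // opt_row     - index tensor, no gradient
//    Variable(),  // rowptr      - index tensor, no gradient
//    Variable(),  // col         - index tensor, no gradient
//    grad_value,  // value
//    Variable(),  // opt_colptr  - index tensor, no gradient
//    Variable(),  // opt_csr2csc - index tensor, no gradient
//    grad_mat,    // mat
//    Variable()}; // has_value   - plain bool, no gradient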
@@ -96,25 +107,37 @@ public:
                            torch::optional<Variable> opt_rowcount,
                            torch::optional<Variable> opt_colptr,
                            torch::optional<Variable> opt_csr2csc,
-                           Variable mat) {
-    auto row = opt_row.has_value() ? opt_row.value() : torch::Tensor();
-    auto rowcount = opt_rowcount.has_value() ? opt_rowcount.value() : torch::Tensor();
-    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : torch::Tensor();
-    auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : torch::Tensor();
+                           Variable mat, bool has_value) {
+
+    if (has_value && torch::autograd::any_variable_requires_grad({value})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+    }
+    if (torch::autograd::any_variable_requires_grad({mat})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+      AT_ASSERTM(opt_rowcount.has_value(), "Argument `rowcount` is missing");
+      AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
+      AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
+    }
+
+    auto row = opt_row.has_value() ? opt_row.value() : col;
+    auto rowcount = opt_rowcount.has_value() ? opt_rowcount.value() : col;
+    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
+    auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;
 
     torch::optional<torch::Tensor> opt_value = torch::nullopt;
-    if (value.numel() > 0)
+    if (has_value)
       opt_value = value;
 
     auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "mean"));
+    ctx->saved_data["has_value"] = has_value;
     ctx->save_for_backward(
         {row, rowptr, col, value, rowcount, colptr, csr2csc, mat});
     return {out};
   }
 
   static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
+    auto has_value = ctx->saved_data["has_value"].toBool();
     auto grad_out = grad_outs[0];
     auto saved = ctx->get_saved_variables();
     auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
@@ -122,8 +145,7 @@ public:
          mat = saved[7];
 
     auto grad_value = Variable();
-    if (value.numel() > 0 &&
-        torch::autograd::any_variable_requires_grad({value})) {
+    if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
       grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "mean");
     }
@@ -133,7 +155,7 @@ public:
       rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row);
       rowcount.clamp_(1);
 
-      if (value.numel() > 0)
+      if (has_value > 0)
         rowcount = value.index_select(0, csr2csc).div(rowcount);
       else
         rowcount.pow_(-1);
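For the "mean" reduction, the backward pass rescales the incoming gradient by inverse per-destination-row counts before the transposed SpMM; this hunk only swaps the guard from value.numel() > 0 to has_value > 0. A standalone sketch of the no-values branch above, with assumed names and shapes (not library code):

#include <torch/torch.h>

// Per-edge scale used when no explicit edge values exist: gather each
// edge's destination-row count, clamp to 1 so empty rows do not divide
// by zero, and invert.
torch::Tensor inverse_rowcount_per_edge(torch::Tensor rowcount,
                                        torch::Tensor row,
                                        torch::ScalarType dtype) {
  auto scale = rowcount.toType(dtype).index_select(0, row);
  scale.clamp_(1);
  return scale.pow_(-1);
}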
@@ -141,8 +163,8 @@ public:
       grad_mat = std::get<0>(spmm_fw(colptr, row, rowcount, grad_out, "sum"));
     }
 
-    return {Variable(), Variable(), Variable(), grad_value, Variable(),
-            Variable(), Variable(), grad_mat};
+    return {Variable(), Variable(), Variable(), grad_value, Variable(),
+            Variable(), Variable(), grad_mat, Variable()};
   }
 };
@@ -152,11 +174,9 @@ torch::Tensor spmm_sum(torch::optional<torch::Tensor> opt_row,
                        torch::optional<torch::Tensor> opt_colptr,
                        torch::optional<torch::Tensor> opt_csr2csc,
                        torch::Tensor mat) {
-  // Since we cannot return an *optional* gradient, we need to convert
-  // `opt_value` to an empty sized tensor first :(
-  auto value = opt_value.has_value() ? opt_value.value() : torch::Tensor();
-  return SPMMSum::apply(opt_row, rowptr, col, value, opt_colptr, opt_csr2csc,
-                        mat)[0];
+  auto value = opt_value.has_value() ? opt_value.value() : col;
+  return SPMMSum::apply(opt_row, rowptr, col, value, opt_colptr, opt_csr2csc,
+                        mat, opt_value.has_value())[0];
 }
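The deleted comment documents the old workaround: a custom autograd function cannot accept an optional input, so an absent opt_value used to be converted to a default-constructed, undefined torch::Tensor(), which is presumably the "no value" breakage this commit fixes. The stand-in is now col, an arbitrary but already-defined tensor, while the real presence information travels separately as opt_value.has_value(). The pattern in isolation (wrapper context hypothetical):

// Idea: avoid handing an undefined tensor to an autograd Function; pass a
// defined stand-in plus an explicit presence flag instead.
torch::Tensor value = opt_value.has_value() ? opt_value.value() : col;
bool has_value = opt_value.has_value();
// SomeFunction::apply(..., value, ..., has_value) then consults has_value
// and ignores the stand-in whenever it is false.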
@@ -166,9 +186,9 @@ torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
                         torch::optional<torch::Tensor> opt_colptr,
                         torch::optional<torch::Tensor> opt_csr2csc,
                         torch::Tensor mat) {
-  auto value = opt_value.has_value() ? opt_value.value() : torch::Tensor();
-  return SPMMMean::apply(opt_row, rowptr, col, value, opt_rowcount, opt_colptr,
-                         opt_csr2csc, mat)[0];
+  auto value = opt_value.has_value() ? opt_value.value() : col;
+  return SPMMMean::apply(opt_row, rowptr, col, value, opt_rowcount, opt_colptr,
+                         opt_csr2csc, mat, opt_value.has_value())[0];
 }
 
 static auto registry = torch::RegisterOperators()
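The elided tail of the file registers the wrappers with the dispatcher via torch::RegisterOperators, so they become callable from Python and TorchScript. The actual schema strings do not appear in this diff; a generic sketch of the mechanism, with a hypothetical operator name and body:

#include <torch/script.h>

// Hypothetical registration: binds a C++ callable to a namespaced operator
// name; the schema is inferred from the lambda's signature.
static auto registry = torch::RegisterOperators().op(
    "example::add_one", [](torch::Tensor x) { return x + 1; });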