Commit bb1ba6b0 authored by rusty1s

fixed no value

Handle sparse matrices without an explicit `value` tensor: thread a `has_value` flag through `SPMMSum`/`SPMMMean` via `ctx->saved_data` instead of testing `value.numel() > 0` on an empty sentinel, fall back to `col` as a defined placeholder for the missing optional tensors, and return an extra undefined gradient slot for the new argument.

parent 28cb8de4
@@ -46,37 +46,48 @@ public:
                    Variable rowptr, Variable col, Variable value,
                    torch::optional<Variable> opt_colptr,
                    torch::optional<Variable> opt_csr2csc,
-                   Variable mat) {
-    auto row = opt_row.has_value() ? opt_row.value() : torch::Tensor();
-    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : torch::Tensor();
-    auto csr2csc =
-        opt_csr2csc.has_value() ? opt_csr2csc.value() : torch::Tensor();
+                   Variable mat, bool has_value) {
+
+    if (has_value && torch::autograd::any_variable_requires_grad({value})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+    }
+
+    if (torch::autograd::any_variable_requires_grad({mat})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+      AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
+      AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
+    }
+
+    auto row = opt_row.has_value() ? opt_row.value() : col;
+    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
+    auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;

     torch::optional<torch::Tensor> opt_value = torch::nullopt;
-    if (value.numel() > 0)
+    if (has_value)
       opt_value = value;

     auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "sum"));
+    ctx->saved_data["has_value"] = has_value;
     ctx->save_for_backward({row, rowptr, col, value, colptr, csr2csc, mat});
     return {out};
   }

   static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
+    auto has_value = ctx->saved_data["has_value"].toBool();
     auto grad_out = grad_outs[0];
     auto saved = ctx->get_saved_variables();
     auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
          colptr = saved[4], csr2csc = saved[5], mat = saved[6];

     auto grad_value = Variable();
-    if (value.numel() > 0 &&
-        torch::autograd::any_variable_requires_grad({value})) {
+    if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
       grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "sum");
     }

     auto grad_mat = Variable();
     if (torch::autograd::any_variable_requires_grad({mat})) {
       torch::optional<torch::Tensor> opt_value = torch::nullopt;
-      if (value.numel() > 0)
+      if (has_value)
         opt_value = value.index_select(0, csr2csc);

       grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc),
@@ -84,7 +95,7 @@ public:
     }

     return {Variable(), Variable(), Variable(), grad_value,
-            Variable(), Variable(), grad_mat};
+            Variable(), Variable(), grad_mat, Variable()};
   }
 };
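The mechanism behind this change, as visible in SPMMSum above: a plain `bool` cannot go through `ctx->save_for_backward()`, which only accepts Variables, so `has_value` is stashed in `ctx->saved_data` (an IValue map) and read back with `.toBool()` in `backward()`. The minimal sketch below illustrates just that pattern; `ScaleIfFlagged` is a hypothetical example, not part of this library:

#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// Hypothetical function: doubles `x` when `apply_scale` is set.
class ScaleIfFlagged : public torch::autograd::Function<ScaleIfFlagged> {
public:
  static variable_list forward(AutogradContext *ctx, Variable x,
                               bool apply_scale) {
    // Non-tensor arguments go into `saved_data` as IValues.
    ctx->saved_data["apply_scale"] = apply_scale;
    return {apply_scale ? x * 2 : x};
  }

  static variable_list backward(AutogradContext *ctx,
                                variable_list grad_outs) {
    auto apply_scale = ctx->saved_data["apply_scale"].toBool();
    auto grad_x = apply_scale ? grad_outs[0] * 2 : grad_outs[0];
    // backward() returns one slot per forward input; the non-tensor
    // `apply_scale` gets an undefined Variable, just like the trailing
    // Variable() in SPMMSum::backward above.
    return {grad_x, Variable()};
  }
};

// Usage: auto y = ScaleIfFlagged::apply(x, /*apply_scale=*/true)[0];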
@@ -96,25 +107,37 @@ public:
                    torch::optional<Variable> opt_rowcount,
                    torch::optional<Variable> opt_colptr,
                    torch::optional<Variable> opt_csr2csc,
-                   Variable mat) {
-    auto row = opt_row.has_value() ? opt_row.value() : torch::Tensor();
-    auto rowcount =
-        opt_rowcount.has_value() ? opt_rowcount.value() : torch::Tensor();
-    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : torch::Tensor();
-    auto csr2csc =
-        opt_csr2csc.has_value() ? opt_csr2csc.value() : torch::Tensor();
+                   Variable mat, bool has_value) {
+
+    if (has_value && torch::autograd::any_variable_requires_grad({value})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+    }
+
+    if (torch::autograd::any_variable_requires_grad({mat})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+      AT_ASSERTM(opt_rowcount.has_value(), "Argument `rowcount` is missing");
+      AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
+      AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
+    }
+
+    auto row = opt_row.has_value() ? opt_row.value() : col;
+    auto rowcount = opt_rowcount.has_value() ? opt_rowcount.value() : col;
+    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
+    auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;

     torch::optional<torch::Tensor> opt_value = torch::nullopt;
-    if (value.numel() > 0)
+    if (has_value)
       opt_value = value;

     auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "mean"));
+    ctx->saved_data["has_value"] = has_value;
     ctx->save_for_backward(
         {row, rowptr, col, value, rowcount, colptr, csr2csc, mat});
     return {out};
   }

   static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
+    auto has_value = ctx->saved_data["has_value"].toBool();
     auto grad_out = grad_outs[0];
     auto saved = ctx->get_saved_variables();
     auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
@@ -122,8 +145,7 @@ public:
          mat = saved[7];

     auto grad_value = Variable();
-    if (value.numel() > 0 &&
-        torch::autograd::any_variable_requires_grad({value})) {
+    if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
       grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "mean");
     }

@@ -133,7 +155,7 @@ public:
       rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row);
       rowcount.clamp_(1);
-      if (value.numel() > 0)
+      if (has_value > 0)
         rowcount = value.index_select(0, csr2csc).div(rowcount);
       else
         rowcount.pow_(-1);
@@ -141,8 +163,8 @@ public:
       grad_mat = std::get<0>(spmm_fw(colptr, row, rowcount, grad_out, "sum"));
     }

-    return {Variable(), Variable(), Variable(), grad_value,
-            Variable(), Variable(), Variable(), grad_mat};
+    return {Variable(), Variable(), Variable(), grad_value, Variable(),
+            Variable(), Variable(), grad_mat, Variable()};
   }
 };
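For reference, the arithmetic behind SPMMMean's backward (in the notation of this file): the forward computes

    out[i] = (1 / max(deg(i), 1)) * sum_{e : row[e] = i} value[e] * mat[col[e]]

with value[e] taken as 1 when has_value is false, so the gradient with respect to the dense matrix is

    grad_mat[j] = sum_{e : col[e] = j} (value[e] / max(deg(row[e]), 1)) * grad_out[row[e]].

That is a transposed SpMM whose edge weights are value / rowcount (or 1 / rowcount in the no-value case), which is exactly what the `rowcount.clamp_(1)` / `div` / `pow_(-1)` lines build before the final `spmm_fw(colptr, row, rowcount, grad_out, "sum")` call.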
@@ -152,11 +174,9 @@ torch::Tensor spmm_sum(torch::optional<torch::Tensor> opt_row,
                        torch::optional<torch::Tensor> opt_colptr,
                        torch::optional<torch::Tensor> opt_csr2csc,
                        torch::Tensor mat) {
-  // Since we cannot return an *optional* gradient, we need to convert
-  // `opt_value` to an empty sized tensor first :(
-  auto value = opt_value.has_value() ? opt_value.value() : torch::Tensor();
+  auto value = opt_value.has_value() ? opt_value.value() : col;
   return SPMMSum::apply(opt_row, rowptr, col, value, opt_colptr, opt_csr2csc,
-                        mat)[0];
+                        mat, opt_value.has_value())[0];
 }

 torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
@@ -166,9 +186,9 @@ torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
                         torch::optional<torch::Tensor> opt_colptr,
                         torch::optional<torch::Tensor> opt_csr2csc,
                         torch::Tensor mat) {
-  auto value = opt_value.has_value() ? opt_value.value() : torch::Tensor();
+  auto value = opt_value.has_value() ? opt_value.value() : col;
   return SPMMMean::apply(opt_row, rowptr, col, value, opt_rowcount, opt_colptr,
-                         opt_csr2csc, mat)[0];
+                         opt_csr2csc, mat, opt_value.has_value())[0];
 }

 static auto registry = torch::RegisterOperators()
...
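A hypothetical call site for the two wrappers, assuming the argument order implied by the `apply(...)` calls above (`opt_row, rowptr, col, opt_value, opt_colptr, opt_csr2csc, mat`); tensors and shapes are illustrative only:

#include <torch/torch.h>

void spmm_sum_example() {
  // CSR sparsity pattern of the 2x2 matrix [[1, 1], [0, 1]].
  auto rowptr = torch::tensor({0, 2, 3}, torch::kLong);
  auto col = torch::tensor({0, 1, 1}, torch::kLong);
  auto mat = torch::rand({2, 4});

  // No explicit values: the wrapper passes has_value = false and reuses
  // `col` as a defined placeholder for the missing tensors (replacing
  // the old undefined torch::Tensor() sentinel).
  auto out = spmm_sum(torch::nullopt, rowptr, col, torch::nullopt,
                      torch::nullopt, torch::nullopt, mat);

  // With a value tensor that requires grad, `row` (and, for mat
  // gradients, colptr/csr2csc) must be supplied as well, or the
  // AT_ASSERTM checks in forward() will fire.
}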