Commit bb1ba6b0 authored by rusty1s

fixed no value

parent 28cb8de4
@@ -46,37 +46,48 @@ public:
                                Variable rowptr, Variable col, Variable value,
                                torch::optional<Variable> opt_colptr,
                                torch::optional<Variable> opt_csr2csc,
-                               Variable mat) {
-    auto row = opt_row.has_value() ? opt_row.value() : torch::Tensor();
-    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : torch::Tensor();
-    auto csr2csc =
-        opt_csr2csc.has_value() ? opt_csr2csc.value() : torch::Tensor();
+                               Variable mat, bool has_value) {
+    if (has_value && torch::autograd::any_variable_requires_grad({value})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+    }
+    if (torch::autograd::any_variable_requires_grad({mat})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+      AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
+      AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
+    }
+
+    auto row = opt_row.has_value() ? opt_row.value() : col;
+    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
+    auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;
+
     torch::optional<torch::Tensor> opt_value = torch::nullopt;
-    if (value.numel() > 0)
+    if (has_value)
       opt_value = value;
 
     auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "sum"));
+    ctx->saved_data["has_value"] = has_value;
     ctx->save_for_backward({row, rowptr, col, value, colptr, csr2csc, mat});
     return {out};
   }
 
   static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
+    auto has_value = ctx->saved_data["has_value"].toBool();
     auto grad_out = grad_outs[0];
     auto saved = ctx->get_saved_variables();
     auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
          colptr = saved[4], csr2csc = saved[5], mat = saved[6];
 
     auto grad_value = Variable();
-    if (value.numel() > 0 &&
-        torch::autograd::any_variable_requires_grad({value})) {
+    if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
       grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "sum");
     }
 
     auto grad_mat = Variable();
     if (torch::autograd::any_variable_requires_grad({mat})) {
       torch::optional<torch::Tensor> opt_value = torch::nullopt;
-      if (value.numel() > 0)
+      if (has_value)
         opt_value = value.index_select(0, csr2csc);
 
       grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc),
@@ -84,7 +95,7 @@ public:
     }
 
     return {Variable(), Variable(), Variable(), grad_value,
-            Variable(), Variable(), grad_mat};
+            Variable(), Variable(), grad_mat, Variable()};
   }
 };
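
Note: the two additions above follow the usual pattern for routing a non-tensor argument through a C++ `torch::autograd::Function`: the flag is stashed in `ctx->saved_data` (an IValue map) so `backward` can read it back, and since autograd expects one gradient slot per `forward` argument, the bool receives a trailing undefined `Variable()`. A minimal self-contained sketch of that pattern (not from this repository):

#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

struct Scale : public torch::autograd::Function<Scale> {
  static variable_list forward(AutogradContext *ctx, Variable x, bool negate) {
    ctx->saved_data["negate"] = negate; // non-tensor state survives to backward
    return {negate ? -x : x};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto negate = ctx->saved_data["negate"].toBool();
    auto grad_x = negate ? -grad_outs[0] : grad_outs[0];
    // One slot per forward argument: the bool flag gets an undefined
    // Variable, just like the trailing Variable() added in SPMMSum above.
    return {grad_x, Variable()};
  }
};

Calling `Scale::apply(x, true)` then behaves like `-x` under autograd, with no gradient produced for the flag.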
@@ -96,25 +107,37 @@ public:
                                torch::optional<Variable> opt_rowcount,
                                torch::optional<Variable> opt_colptr,
                                torch::optional<Variable> opt_csr2csc,
-                               Variable mat) {
-    auto row = opt_row.has_value() ? opt_row.value() : torch::Tensor();
-    auto rowcount =
-        opt_rowcount.has_value() ? opt_rowcount.value() : torch::Tensor();
-    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : torch::Tensor();
-    auto csr2csc =
-        opt_csr2csc.has_value() ? opt_csr2csc.value() : torch::Tensor();
+                               Variable mat, bool has_value) {
+    if (has_value && torch::autograd::any_variable_requires_grad({value})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+    }
+    if (torch::autograd::any_variable_requires_grad({mat})) {
+      AT_ASSERTM(opt_row.has_value(), "Argument `row` is missing");
+      AT_ASSERTM(opt_rowcount.has_value(), "Argument `rowcount` is missing");
+      AT_ASSERTM(opt_colptr.has_value(), "Argument `colptr` is missing");
+      AT_ASSERTM(opt_csr2csc.has_value(), "Argument `csr2csc` is missing");
+    }
+
+    auto row = opt_row.has_value() ? opt_row.value() : col;
+    auto rowcount = opt_rowcount.has_value() ? opt_rowcount.value() : col;
+    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : col;
+    auto csr2csc = opt_csr2csc.has_value() ? opt_csr2csc.value() : col;
+
     torch::optional<torch::Tensor> opt_value = torch::nullopt;
-    if (value.numel() > 0)
+    if (has_value)
      opt_value = value;
 
     auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "mean"));
+    ctx->saved_data["has_value"] = has_value;
     ctx->save_for_backward(
         {row, rowptr, col, value, rowcount, colptr, csr2csc, mat});
     return {out};
   }
 
   static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
+    auto has_value = ctx->saved_data["has_value"].toBool();
     auto grad_out = grad_outs[0];
     auto saved = ctx->get_saved_variables();
     auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
@@ -122,8 +145,7 @@ public:
          mat = saved[7];
 
     auto grad_value = Variable();
-    if (value.numel() > 0 &&
-        torch::autograd::any_variable_requires_grad({value})) {
+    if (has_value > 0 && torch::autograd::any_variable_requires_grad({value})) {
       grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "mean");
     }
@@ -133,7 +155,7 @@ public:
       rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row);
       rowcount.clamp_(1);
 
-      if (value.numel() > 0)
+      if (has_value > 0)
         rowcount = value.index_select(0, csr2csc).div(rowcount);
       else
         rowcount.pow_(-1);
@@ -141,8 +163,8 @@
       grad_mat = std::get<0>(spmm_fw(colptr, row, rowcount, grad_out, "sum"));
     }
 
-    return {Variable(), Variable(), Variable(), grad_value,
-            Variable(), Variable(), Variable(), grad_mat};
+    return {Variable(), Variable(), Variable(), grad_value, Variable(),
+            Variable(), Variable(), grad_mat, Variable()};
   }
 };
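
For reference, the `mean` backward kept above is consistent with the derivative of a row-normalized aggregation. Writing $d_i$ for the number of entries in row $i$ and $v_{ij}$ for the (optional) edge value, the forward computes

$\mathrm{out}_i = \frac{1}{\max(d_i, 1)} \sum_{j \in \mathcal{N}(i)} v_{ij}\,\mathrm{mat}_j, \qquad \frac{\partial\,\mathrm{out}_i}{\partial\,\mathrm{mat}_j} = \frac{v_{ij}}{\max(d_i, 1)}.$

This is a reading of the visible lines, not a statement about the hidden ones: `rowcount.clamp_(1)` guards the division for empty rows, the `has_value` branch scales by `value / rowcount`, and the valueless branch falls back to `rowcount.pow_(-1)`, i.e. $v_{ij} \equiv 1$.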
@@ -152,11 +174,9 @@ torch::Tensor spmm_sum(torch::optional<torch::Tensor> opt_row,
                        torch::optional<torch::Tensor> opt_colptr,
                        torch::optional<torch::Tensor> opt_csr2csc,
                        torch::Tensor mat) {
-  // Since we cannot return an *optional* gradient, we need to convert
-  // `opt_value` to an empty sized tensor first :(
-  auto value = opt_value.has_value() ? opt_value.value() : torch::Tensor();
+  auto value = opt_value.has_value() ? opt_value.value() : col;
   return SPMMSum::apply(opt_row, rowptr, col, value, opt_colptr, opt_csr2csc,
-                        mat)[0];
+                        mat, opt_value.has_value())[0];
 }
 
 torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
@@ -166,9 +186,9 @@ torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
                         torch::optional<torch::Tensor> opt_colptr,
                         torch::optional<torch::Tensor> opt_csr2csc,
                         torch::Tensor mat) {
-  auto value = opt_value.has_value() ? opt_value.value() : torch::Tensor();
+  auto value = opt_value.has_value() ? opt_value.value() : col;
   return SPMMMean::apply(opt_row, rowptr, col, value, opt_rowcount, opt_colptr,
-                         opt_csr2csc, mat)[0];
+                         opt_csr2csc, mat, opt_value.has_value())[0];
 }
 
 static auto registry = torch::RegisterOperators()
......
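
In short: the wrappers previously substituted a default-constructed `torch::Tensor()` for every missing optional and the autograd functions probed presence via `value.numel() > 0`, which apparently broke the no-value path (hence the commit title); the fix passes an always-defined dummy (`col`, reused purely as a placeholder on the right device) and communicates presence explicitly through the new `has_value` flag. A hypothetical end-to-end check of the no-value path, compiled against the file above (argument order inferred from the `apply` calls; the CSR/CSC index tensors are hand-built for a 2x2 matrix with all four entries present):

#include <torch/torch.h>

int main() {
  // 2x2 all-ones sparsity pattern in CSR form.
  auto rowptr = torch::tensor({0, 2, 4}, torch::kLong);
  auto col = torch::tensor({0, 1, 0, 1}, torch::kLong);
  auto row = torch::tensor({0, 0, 1, 1}, torch::kLong);
  // CSC companions, required by the assertions because `mat` needs a gradient.
  auto colptr = torch::tensor({0, 2, 4}, torch::kLong);
  auto csr2csc = torch::tensor({0, 2, 1, 3}, torch::kLong);
  auto mat = torch::rand({2, 3}, torch::requires_grad());

  // No `value` given: internally the dummy `col` stands in for it and
  // `has_value = false`, so backward never computes `grad_value`.
  auto out = spmm_sum(row, rowptr, col, torch::nullopt, colptr, csr2csc, mat);
  out.sum().backward(); // populates mat.grad() via the CSC transpose path
}

With `has_value == false`, `backward` skips `spmm_value_bw` entirely and only produces the gradient with respect to `mat`.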