Commit 7134d457 authored by rusty1s's avatar rusty1s
Browse files

fix index_bug half bug

parent 56ec830f
......@@ -34,7 +34,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Free up disk space
if: ${{ runner.os == 'Linux' && matrix.cuda-version == 'cu111' }}
if: ${{ runner.os == 'Linux' }}
run: |
sudo rm -rf /usr/share/dotnet
......
......@@ -30,4 +30,4 @@ echo "PyTorch $TORCH_VERSION+$CUDA_VERSION"
echo "- $CONDA_PYTORCH_CONSTRAINT"
echo "- $CONDA_CUDATOOLKIT_CONSTRAINT"
conda build . -c defaults -c nvidia -c pytorch -c conda-forge -c rusty1s --output-folder "$HOME/conda-bld"
conda build . -c pytorch -c nvidia -c rusty1s -c defaults -c conda-forge --output-folder "$HOME/conda-bld"
......@@ -97,7 +97,7 @@ public:
if (torch::autograd::any_variable_requires_grad({mat})) {
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value)
opt_value = value.index_select(0, csr2csc);
opt_value = value.view({-1, 1}).index_select(0, csr2csc).view(-1);
grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc),
opt_value, grad_out, "sum"));
......@@ -161,11 +161,12 @@ public:
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
row = row.index_select(0, csr2csc);
rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row);
rowcount = rowcount.index_select(0, row).toType(mat.scalar_type());
rowcount.masked_fill_(rowcount < 1, 1);
if (has_value > 0)
rowcount = value.index_select(0, csr2csc).div(rowcount);
rowcount =
value.view({-1, 1}).index_select(0, csr2csc).view(-1).div(rowcount);
else
rowcount.pow_(-1);
......@@ -219,8 +220,10 @@ public:
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
if (has_value > 0) {
value = value.index_select(0, arg_out.flatten()).view_as(arg_out);
value.mul_(grad_out);
value = value.view({-1, 1})
.index_select(0, arg_out.flatten())
.view_as(arg_out)
.mul_(grad_out);
} else
value = grad_out;
......@@ -277,8 +280,10 @@ public:
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
if (has_value > 0) {
value = value.index_select(0, arg_out.flatten()).view_as(arg_out);
value.mul_(grad_out);
value = value.view({-1, 1})
.index_select(0, arg_out.flatten())
.view_as(arg_out)
.mul_(grad_out);
} else
value = grad_out;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment