Commit 7134d457 authored by rusty1s's avatar rusty1s
Browse files

fix index_bug half bug

parent 56ec830f
...@@ -34,7 +34,7 @@ jobs: ...@@ -34,7 +34,7 @@ jobs:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Free up disk space - name: Free up disk space
if: ${{ runner.os == 'Linux' && matrix.cuda-version == 'cu111' }} if: ${{ runner.os == 'Linux' }}
run: | run: |
sudo rm -rf /usr/share/dotnet sudo rm -rf /usr/share/dotnet
......
...@@ -30,4 +30,4 @@ echo "PyTorch $TORCH_VERSION+$CUDA_VERSION" ...@@ -30,4 +30,4 @@ echo "PyTorch $TORCH_VERSION+$CUDA_VERSION"
echo "- $CONDA_PYTORCH_CONSTRAINT" echo "- $CONDA_PYTORCH_CONSTRAINT"
echo "- $CONDA_CUDATOOLKIT_CONSTRAINT" echo "- $CONDA_CUDATOOLKIT_CONSTRAINT"
conda build . -c defaults -c nvidia -c pytorch -c conda-forge -c rusty1s --output-folder "$HOME/conda-bld" conda build . -c pytorch -c nvidia -c rusty1s -c defaults -c conda-forge --output-folder "$HOME/conda-bld"
...@@ -97,7 +97,7 @@ public: ...@@ -97,7 +97,7 @@ public:
if (torch::autograd::any_variable_requires_grad({mat})) { if (torch::autograd::any_variable_requires_grad({mat})) {
torch::optional<torch::Tensor> opt_value = torch::nullopt; torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (has_value) if (has_value)
opt_value = value.index_select(0, csr2csc); opt_value = value.view({-1, 1}).index_select(0, csr2csc).view(-1);
grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc), grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc),
opt_value, grad_out, "sum")); opt_value, grad_out, "sum"));
...@@ -161,11 +161,12 @@ public: ...@@ -161,11 +161,12 @@ public:
auto grad_mat = Variable(); auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) { if (torch::autograd::any_variable_requires_grad({mat})) {
row = row.index_select(0, csr2csc); row = row.index_select(0, csr2csc);
rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row); rowcount = rowcount.index_select(0, row).toType(mat.scalar_type());
rowcount.masked_fill_(rowcount < 1, 1); rowcount.masked_fill_(rowcount < 1, 1);
if (has_value > 0) if (has_value > 0)
rowcount = value.index_select(0, csr2csc).div(rowcount); rowcount =
value.view({-1, 1}).index_select(0, csr2csc).view(-1).div(rowcount);
else else
rowcount.pow_(-1); rowcount.pow_(-1);
...@@ -219,8 +220,10 @@ public: ...@@ -219,8 +220,10 @@ public:
auto grad_mat = Variable(); auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) { if (torch::autograd::any_variable_requires_grad({mat})) {
if (has_value > 0) { if (has_value > 0) {
value = value.index_select(0, arg_out.flatten()).view_as(arg_out); value = value.view({-1, 1})
value.mul_(grad_out); .index_select(0, arg_out.flatten())
.view_as(arg_out)
.mul_(grad_out);
} else } else
value = grad_out; value = grad_out;
...@@ -277,8 +280,10 @@ public: ...@@ -277,8 +280,10 @@ public:
auto grad_mat = Variable(); auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) { if (torch::autograd::any_variable_requires_grad({mat})) {
if (has_value > 0) { if (has_value > 0) {
value = value.index_select(0, arg_out.flatten()).view_as(arg_out); value = value.view({-1, 1})
value.mul_(grad_out); .index_select(0, arg_out.flatten())
.view_as(arg_out)
.mul_(grad_out);
} else } else
value = grad_out; value = grad_out;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment