"examples/community/pipeline_stable_diffusion_boxdiff.py" did not exist on "aa82df52e719f22a51f2881ebe15d2904586160a"
Commit f9ce729f authored by rusty1s's avatar rusty1s
Browse files

fixed a crucial bug in ptr2ind

parent 1f175220
...@@ -57,7 +57,7 @@ torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) { ...@@ -57,7 +57,7 @@ torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
auto ptr_data = ptr.data_ptr<int64_t>(); auto ptr_data = ptr.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>(); auto out_data = out.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
ptr2ind_kernel<<<(ptr.numel() + THREADS - 1) / THREADS, THREADS, 0, stream>>>( ptr2ind_kernel<<<(ptr.numel() - 1 + THREADS - 1) / THREADS, THREADS, 0,
ptr_data, out_data, E, ptr.numel()); stream>>>(ptr_data, out_data, E, ptr.numel() - 1);
return out; return out;
} }
...@@ -39,7 +39,8 @@ def test_cat(device): ...@@ -39,7 +39,8 @@ def test_cat(device):
assert out.storage.has_rowptr() assert out.storage.has_rowptr()
assert out.storage.num_cached_keys() == 5 assert out.storage.num_cached_keys() == 5
mat1 = mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device)) value = torch.randn((mat1.nnz(), 4), device=device)
mat1 = mat1.set_value_(value, layout='coo')
out = cat([mat1, mat1], dim=-1) out = cat([mat1, mat1], dim=-1)
assert out.storage.value().size() == (mat1.nnz(), 8) assert out.storage.value().size() == (mat1.nnz(), 8)
assert out.storage.has_row() assert out.storage.has_row()
......
...@@ -40,7 +40,7 @@ def test_spmm(dtype, device, reduce): ...@@ -40,7 +40,7 @@ def test_spmm(dtype, device, reduce):
out = matmul(src, other, reduce) out = matmul(src, other, reduce)
out.backward(grad_out) out.backward(grad_out)
assert torch.allclose(expected, out) assert torch.allclose(expected, out, atol=1e-6)
assert torch.allclose(expected_grad_value, value.grad, atol=1e-6) assert torch.allclose(expected_grad_value, value.grad, atol=1e-6)
assert torch.allclose(expected_grad_other, other.grad, atol=1e-6) assert torch.allclose(expected_grad_other, other.grad, atol=1e-6)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment