Commit ffec0a56 authored by rusty1s's avatar rusty1s
Browse files

coalesce call

parent 5788c855
...@@ -28,6 +28,9 @@ static void init_cusparse() { ...@@ -28,6 +28,9 @@ static void init_cusparse() {
} }
std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) { std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
A = A.coalesce();
B = B.coalesce();
init_cusparse(); init_cusparse();
auto m = A.size(0); auto m = A.size(0);
......
import torch
def SparseTensor(index, value, size):
t = torch.cuda if value.is_cuda else torch
SparseTensor = getattr(t.sparse, value.type().split('.')[-1])
return SparseTensor(index, value, size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment