Commit 1a4bdd30 authored by rusty1s's avatar rusty1s
Browse files

sparse matrix ops request floating point data types

parent 9fb07940
......@@ -163,17 +163,17 @@ from torch_sparse import spmm
index = torch.tensor([[0, 0, 1, 2, 2],
[0, 2, 1, 0, 1]])
value = torch.tensor([1, 2, 4, 1, 3])
matrix = torch.tensor([[1, 4], [2, 5], [3, 6]])
value = torch.tensor([1, 2, 4, 1, 3], dtype=torch.float)
matrix = torch.tensor([[1, 4], [2, 5], [3, 6]], dtype=torch.float)
out = spmm(index, value, 3, matrix)
```
```
print(out)
tensor([[7, 16],
[8, 20],
[7, 19]])
tensor([[7.0, 16.0],
[8.0, 20.0],
[7.0, 19.0]])
```
## Sparse Sparse Matrix Multiplication
......@@ -206,10 +206,10 @@ Both input sparse matrices need to be **coalesced**.
from torch_sparse import spspmm
indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
valueA = torch.tensor([1, 2, 3, 4, 5])
valueA = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)
indexB = torch.tensor([[0, 2], [1, 0]])
valueB = torch.tensor([2, 4])
valueB = torch.tensor([2, 4], dtype=torch.float)
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
```
......@@ -219,7 +219,7 @@ print(index)
tensor([[0, 1, 2],
[0, 1, 1]])
print(value)
tensor([8, 6, 8])
tensor([8.0, 6.0, 8.0])
```
## Running tests
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment