Commit 1a4bdd30 authored by rusty1s's avatar rusty1s
Browse files

sparse matrix ops request floating point data types

parent 9fb07940
...@@ -163,17 +163,17 @@ from torch_sparse import spmm ...@@ -163,17 +163,17 @@ from torch_sparse import spmm
index = torch.tensor([[0, 0, 1, 2, 2], index = torch.tensor([[0, 0, 1, 2, 2],
[0, 2, 1, 0, 1]]) [0, 2, 1, 0, 1]])
value = torch.tensor([1, 2, 4, 1, 3]) value = torch.tensor([1, 2, 4, 1, 3], dtype=torch.float)
matrix = torch.tensor([[1, 4], [2, 5], [3, 6]]) matrix = torch.tensor([[1, 4], [2, 5], [3, 6]], dtype=torch.float)
out = spmm(index, value, 3, matrix) out = spmm(index, value, 3, matrix)
``` ```
``` ```
print(out) print(out)
tensor([[7, 16], tensor([[7.0, 16.0],
[8, 20], [8.0, 20.0],
[7, 19]]) [7.0, 19.0]])
``` ```
## Sparse Sparse Matrix Multiplication ## Sparse Sparse Matrix Multiplication
...@@ -206,10 +206,10 @@ Both input sparse matrices need to be **coalesced**. ...@@ -206,10 +206,10 @@ Both input sparse matrices need to be **coalesced**.
from torch_sparse import spspmm from torch_sparse import spspmm
indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]]) indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
valueA = torch.tensor([1, 2, 3, 4, 5]) valueA = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)
indexB = torch.tensor([[0, 2], [1, 0]]) indexB = torch.tensor([[0, 2], [1, 0]])
valueB = torch.tensor([2, 4]) valueB = torch.tensor([2, 4], dtype=torch.float)
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2) indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
``` ```
...@@ -219,7 +219,7 @@ print(index) ...@@ -219,7 +219,7 @@ print(index)
tensor([[0, 1, 2], tensor([[0, 1, 2],
[0, 1, 1]]) [0, 1, 1]])
print(value) print(value)
tensor([8, 6, 8]) tensor([8.0, 6.0, 8.0])
``` ```
## Running tests ## Running tests
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment