Commit bfdca11e authored by rusty1s's avatar rusty1s
Browse files

fix GPU device handling

parent c4c6db4a
...@@ -2,6 +2,7 @@ from typing import Optional ...@@ -2,6 +2,7 @@ from typing import Optional
import torch import torch
from torch import Tensor from torch import Tensor
from torch_sparse.storage import SparseStorage from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
...@@ -97,7 +98,7 @@ def get_diag(src: SparseTensor) -> Tensor: ...@@ -97,7 +98,7 @@ def get_diag(src: SparseTensor) -> Tensor:
row, col, value = src.coo() row, col, value = src.coo()
if value is None: if value is None:
value = torch.ones(row.size(0)) value = torch.ones(row.size(0), device=row.device)
sizes = list(value.size()) sizes = list(value.size())
sizes[0] = min(src.size(0), src.size(1)) sizes[0] = min(src.size(0), src.size(1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment