Commit a597b822 authored by rusty1s's avatar rusty1s
Browse files

only cpu test

parent eda4b3d7
import pytest
import torch import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
from .utils import devices
def test_saint_subgraph():
@pytest.mark.parametrize('device', devices)
def test_saint_subgraph(device):
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4]) row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3]) col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
adj = SparseTensor(row=row, col=col).to(device) adj = SparseTensor(row=row, col=col)
node_idx = torch.tensor([0, 1, 2]) node_idx = torch.tensor([0, 1, 2])
adj, edge_index = adj.saint_subgraph(node_idx) adj, edge_index = adj.saint_subgraph(node_idx)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment