# import torch
# from torch_cluster import serial_cluster


def test_serial():
    """Placeholder test for serial graph clustering; currently a no-op.

    The reference check below is commented out, as are the ``torch`` /
    ``torch_cluster`` imports at the top of the file, so this test always
    passes without exercising anything. Re-enable both together once the
    implementation is available.
    """
    # Reference graph: two rows of 14 node indices each, nodes 0..6.
    # edge_index = torch.LongTensor([[0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6],
    #                                [2, 3, 6, 5, 0, 0, 4, 5, 3, 1, 3, 6, 0, 3]])
    # NOTE(review): edge_attr is built but never passed to the call below.
    # edge_attr = torch.Tensor([2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2])
    # One entry per node (0..max index), presumably a fixed node ordering
    # so the result is deterministic -- confirm against the implementation.
    # rid = torch.arange(edge_index.max() + 1, out=edge_index.new())
    # NOTE(review): this calls random_cluster, but the commented import at
    # the top of the file brings in serial_cluster (see "Rename" TODO).
    # output = random_cluster(edge_index, rid=rid, perm_edges=False)

    # expected_output = [0, 1, 2, 0, 3, 1, 4]
    # assert output.tolist() == expected_output

    # TODO: Test only if conditions are met:
    # * at most two pairs with the same cluster
    # * pairs need to be neighbors of each other
    # TODO: Rename to serial