"git@developer.sourcefind.cn:OpenDAS/pytorch-encoding.git" did not exist on "abcee3c9316e634dae93b1923dfeda403ade7888"
Unverified Commit 81c7781b authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[Mics] Provide options for bidirectional edge (#6566)

parent 0024c7e1
...@@ -9,9 +9,10 @@ import scipy.sparse as sp ...@@ -9,9 +9,10 @@ import scipy.sparse as sp
import torch import torch
def rand_csc_graph(N, density): def rand_csc_graph(N, density, bidirection_edge=False):
adj = sp.random(N, N, density) adj = sp.random(N, N, density)
adj = adj + adj.T if bidirection_edge:
adj = adj + adj.T
adj = adj.tocsc() adj = adj.tocsc()
indptr = torch.LongTensor(adj.indptr) indptr = torch.LongTensor(adj.indptr)
......
...@@ -32,7 +32,7 @@ def test_NegativeSampler_invoke(): ...@@ -32,7 +32,7 @@ def test_NegativeSampler_invoke():
def test_UniformNegativeSampler_invoke(): def test_UniformNegativeSampler_invoke():
# Instantiate graph and required datapipes. # Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05) graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs" torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs"
...@@ -69,7 +69,7 @@ def test_UniformNegativeSampler_invoke(): ...@@ -69,7 +69,7 @@ def test_UniformNegativeSampler_invoke():
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20]) @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_Uniform_NegativeSampler(negative_ratio): def test_Uniform_NegativeSampler(negative_ratio):
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05) graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs" torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
......
...@@ -29,7 +29,7 @@ def test_CopyToWithMiniBatches(): ...@@ -29,7 +29,7 @@ def test_CopyToWithMiniBatches():
N = 16 N = 16
B = 2 B = 2
itemset = gb.ItemSet(torch.arange(N), names="seed_nodes") itemset = gb.ItemSet(torch.arange(N), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(100, 0.15) graph = gb_test_utils.rand_csc_graph(100, 0.15, bidirection_edge=True)
features = {} features = {}
keys = [("node", None, "a"), ("node", None, "b")] keys = [("node", None, "a"), ("node", None, "b")]
......
...@@ -8,7 +8,7 @@ from torchdata.datapipes.iter import Mapper ...@@ -8,7 +8,7 @@ from torchdata.datapipes.iter import Mapper
def test_FeatureFetcher_invoke(): def test_FeatureFetcher_invoke():
# Prepare graph and required datapipes. # Prepare graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor( a = torch.tensor(
[[random.randint(0, 10)] for _ in range(graph.total_num_nodes)] [[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
) )
...@@ -40,7 +40,7 @@ def test_FeatureFetcher_invoke(): ...@@ -40,7 +40,7 @@ def test_FeatureFetcher_invoke():
def test_FeatureFetcher_homo(): def test_FeatureFetcher_homo():
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor( a = torch.tensor(
[[random.randint(0, 10)] for _ in range(graph.total_num_nodes)] [[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
) )
...@@ -65,7 +65,7 @@ def test_FeatureFetcher_homo(): ...@@ -65,7 +65,7 @@ def test_FeatureFetcher_homo():
def test_FeatureFetcher_with_edges_homo(): def test_FeatureFetcher_with_edges_homo():
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor( a = torch.tensor(
[[random.randint(0, 10)] for _ in range(graph.total_num_nodes)] [[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
) )
......
...@@ -7,7 +7,7 @@ def test_dgl_minibatch_converter(): ...@@ -7,7 +7,7 @@ def test_dgl_minibatch_converter():
N = 32 N = 32
B = 4 B = 4
itemset = gb.ItemSet(torch.arange(N), names="seed_nodes") itemset = gb.ItemSet(torch.arange(N), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(200, 0.15) graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)
features = {} features = {}
keys = [("node", None, "a"), ("node", None, "b")] keys = [("node", None, "a"), ("node", None, "b")]
......
...@@ -14,7 +14,7 @@ def test_DataLoader(): ...@@ -14,7 +14,7 @@ def test_DataLoader():
N = 40 N = 40
B = 4 B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes") itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(200, 0.15) graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)
features = {} features = {}
keys = [("node", None, "a"), ("node", None, "b")] keys = [("node", None, "a"), ("node", None, "b")]
features[keys[0]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4)) features[keys[0]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
......
...@@ -11,7 +11,7 @@ def test_DataLoader(): ...@@ -11,7 +11,7 @@ def test_DataLoader():
N = 32 N = 32
B = 4 B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes") itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(200, 0.15) graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)
features = {} features = {}
keys = [("node", None, "a"), ("node", None, "b")] keys = [("node", None, "a"), ("node", None, "b")]
......
...@@ -22,7 +22,7 @@ def test_SubgraphSampler_invoke(): ...@@ -22,7 +22,7 @@ def test_SubgraphSampler_invoke():
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_invoke(labor): def test_NeighborSampler_invoke(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
...@@ -43,7 +43,7 @@ def test_NeighborSampler_invoke(labor): ...@@ -43,7 +43,7 @@ def test_NeighborSampler_invoke(labor):
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_fanouts(labor): def test_NeighborSampler_fanouts(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
...@@ -67,7 +67,7 @@ def test_NeighborSampler_fanouts(labor): ...@@ -67,7 +67,7 @@ def test_NeighborSampler_fanouts(labor):
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor): def test_SubgraphSampler_Node(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
...@@ -84,7 +84,7 @@ def to_link_batch(data): ...@@ -84,7 +84,7 @@ def to_link_batch(data):
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link(labor): def test_SubgraphSampler_Link(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs") itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
...@@ -96,7 +96,7 @@ def test_SubgraphSampler_Link(labor): ...@@ -96,7 +96,7 @@ def test_SubgraphSampler_Link(labor):
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(labor): def test_SubgraphSampler_Link_With_Negative(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs") itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment