Unverified Commit e117adac authored by Rhett Ying, committed by GitHub

[GraphBolt] fix testcases on warning messages (#7054)

parent 571340da
@@ -13,17 +13,19 @@ from . import gb_test_utils
 @unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
 def test_CopyTo():
-    item_sampler = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)
+    item_sampler = gb.ItemSampler(
+        gb.ItemSet(torch.arange(20), names="seed_nodes"), 4
+    )
     # Invoke CopyTo via class constructor.
     dp = gb.CopyTo(item_sampler, "cuda")
     for data in dp:
-        assert data.device.type == "cuda"
+        assert data.seed_nodes.device.type == "cuda"
     # Invoke CopyTo via functional form.
     dp = item_sampler.copy_to("cuda")
     for data in dp:
-        assert data.device.type == "cuda"
+        assert data.seed_nodes.device.type == "cuda"

 @pytest.mark.parametrize(
...
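The assertion changes above follow from the ItemSet contract: once the ItemSet is constructed with names=..., ItemSampler yields minibatch objects whose named fields carry the batched items instead of bare tensors, so the device check moves to data.seed_nodes. Below is a minimal CPU-only sketch of that behavior; the isinstance check assumes the default minibatcher wraps each batch in gb.MiniBatch, as the comment added in the last hunk suggests.

import torch
from dgl import graphbolt as gb

item_set = gb.ItemSet(torch.arange(20), names="seed_nodes")
item_sampler = gb.ItemSampler(item_set, batch_size=4)
for minibatch in item_sampler:
    # With names set, each batch is wrapped (assumed: as a gb.MiniBatch) and
    # the sampled IDs live under the seed_nodes field.
    assert isinstance(minibatch, gb.MiniBatch)
    assert minibatch.seed_nodes.numel() == 4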
@@ -77,7 +77,8 @@ def test_FeatureFetcher_with_edges_homo():
         [[random.randint(0, 10)] for _ in range(graph.total_num_edges)]
     )

-    def add_node_and_edge_ids(seeds):
+    def add_node_and_edge_ids(minibatch):
+        seeds = minibatch.seed_nodes
         subgraphs = []
         for _ in range(3):
             sampled_csc = gb.CSCFormatBase(
@@ -103,7 +104,7 @@ def test_FeatureFetcher_with_edges_homo():
     features[keys[1]] = gb.TorchBasedFeature(b)
     feature_store = gb.BasicFeatureStore(features)

-    itemset = gb.ItemSet(torch.arange(10))
+    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
     item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
     fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])
@@ -170,7 +171,8 @@ def test_FeatureFetcher_with_edges_hetero():
     a = torch.tensor([[random.randint(0, 10)] for _ in range(20)])
     b = torch.tensor([[random.randint(0, 10)] for _ in range(50)])

-    def add_node_and_edge_ids(seeds):
+    def add_node_and_edge_ids(minibatch):
+        seeds = minibatch.seed_nodes
         subgraphs = []
         original_edge_ids = {
             "n1:e1:n2": torch.randint(0, 50, (10,)),
@@ -213,7 +215,7 @@ def test_FeatureFetcher_with_edges_hetero():
     itemset = gb.ItemSetDict(
         {
-            "n1": gb.ItemSet(torch.randint(0, 20, (10,))),
+            "n1": gb.ItemSet(torch.randint(0, 20, (10,)), names="seed_nodes"),
         }
     )
     item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
...
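The homo and hetero hunks change the Mapper callback in the same way: the mapped function now receives the whole minibatch and pulls the seeds from minibatch.seed_nodes. A minimal sketch of that pipeline shape, with a hypothetical tag_seeds callback standing in for add_node_and_edge_ids and no feature fetching.

import torch
from dgl import graphbolt as gb
from torchdata.datapipes.iter import Mapper

def tag_seeds(minibatch):
    # The callback receives the full minibatch; for a homogeneous ItemSet the
    # seed node IDs of this batch are a tensor under seed_nodes.
    seeds = minibatch.seed_nodes
    assert seeds.numel() == 2
    return minibatch

itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, tag_seeds)
for minibatch in converter_dp:
    pass  # Every batch has passed through tag_seeds.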
@@ -204,9 +204,16 @@ def test_ItemSet_graphs(batch_size, shuffle, drop_last):
         dgl.rand_graph(num_nodes * (i + 1), num_edges * (i + 1))
         for i in range(num_graphs)
     ]
-    item_set = gb.ItemSet(graphs)
+    item_set = gb.ItemSet(graphs, names="graphs")
+    # DGLGraph is not supported in gb.MiniBatch yet. Let's use a customized
+    # minibatcher to return the original graphs.
+    customized_minibatcher = lambda batch, names: batch
     item_sampler = gb.ItemSampler(
-        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
+        item_set,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        drop_last=drop_last,
+        minibatcher=customized_minibatcher,
     )
     minibatch_num_nodes = []
     minibatch_num_edges = []
@@ -459,13 +466,13 @@ def test_ItemSet_seeds_labels(batch_size, shuffle, drop_last):
 def test_append_with_other_datapipes():
     num_ids = 100
     batch_size = 4
-    item_set = gb.ItemSet(torch.arange(0, num_ids))
+    item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes")
     data_pipe = gb.ItemSampler(item_set, batch_size)
     # torchdata.datapipes.iter.Enumerator
     data_pipe = data_pipe.enumerate()
     for i, (idx, data) in enumerate(data_pipe):
         assert i == idx
-        assert len(data) == batch_size
+        assert len(data.seed_nodes) == batch_size

 @pytest.mark.parametrize("batch_size", [1, 4])
...
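The customized minibatcher in the test_ItemSet_graphs hunk is an escape hatch: DGLGraph items cannot be represented in gb.MiniBatch yet, so the sampler is told to hand back the collated batch unchanged. A minimal sketch of that pattern; the graph sizes and batch size here are arbitrary.

import dgl
from dgl import graphbolt as gb

graphs = [dgl.rand_graph(10 * (i + 1), 20 * (i + 1)) for i in range(4)]
item_set = gb.ItemSet(graphs, names="graphs")
# The minibatcher receives the collated batch and the item names; returning
# the batch as-is skips the gb.MiniBatch wrapping.
identity_minibatcher = lambda batch, names: batch
item_sampler = gb.ItemSampler(
    item_set, batch_size=2, minibatcher=identity_minibatcher
)
for batch in item_sampler:
    # Whatever the collate step produced comes through untouched.
    assert not isinstance(batch, gb.MiniBatch)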