Unverified Commit 6aba92e9 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] fix test cases about datapipe (#6305)

parent 1328baf7
...@@ -10,15 +10,15 @@ import torch ...@@ -10,15 +10,15 @@ import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test") @unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo(): def test_CopyTo():
dp = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4) item_sampler = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)
# Invoke CopyTo via class constructor. # Invoke CopyTo via class constructor.
dp = gb.CopyTo(dp, "cuda") dp = gb.CopyTo(item_sampler, "cuda")
for data in dp: for data in dp:
assert data.device.type == "cuda" assert data.device.type == "cuda"
# Invoke CopyTo via functional form. # Invoke CopyTo via functional form.
dp = dp.copy_to("cuda") dp = item_sampler.copy_to("cuda")
for data in dp: for data in dp:
assert data.device.type == "cuda" assert data.device.type == "cuda"
......
...@@ -17,17 +17,17 @@ def test_FeatureFetcher_invoke(): ...@@ -17,17 +17,17 @@ def test_FeatureFetcher_invoke():
feature_store = gb.BasicFeatureStore(features) feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
datapipe = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
# Invoke FeatureFetcher via class constructor. # Invoke FeatureFetcher via class constructor.
datapipe = gb.NeighborSampler(datapipe, graph, fanouts) datapipe = gb.NeighborSampler(item_sampler, graph, fanouts)
datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"]) datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"])
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
# Invoke FeatureFetcher via functional form. # Invoke FeatureFetcher via functional form.
datapipe = datapipe.sample_neighbor(graph, fanouts).fetch_feature( datapipe = item_sampler.sample_neighbor(graph, fanouts).fetch_feature(
feature_store, ["a"], ["b"] feature_store, ["a"], ["b"]
) )
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
......
...@@ -7,15 +7,15 @@ from torchdata.datapipes.iter import Mapper ...@@ -7,15 +7,15 @@ from torchdata.datapipes.iter import Mapper
def test_SubgraphSampler_invoke(): def test_SubgraphSampler_invoke():
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
datapipe = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
# Invoke via class constructor. # Invoke via class constructor.
datapipe = gb.SubgraphSampler(datapipe) datapipe = gb.SubgraphSampler(item_sampler)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
next(iter(datapipe)) next(iter(datapipe))
# Invokde via functional form. # Invokde via functional form.
datapipe = datapipe.sample_subgraph() datapipe = item_sampler.sample_subgraph()
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
next(iter(datapipe)) next(iter(datapipe))
...@@ -24,20 +24,20 @@ def test_SubgraphSampler_invoke(): ...@@ -24,20 +24,20 @@ def test_SubgraphSampler_invoke():
def test_NeighborSampler_invoke(labor): def test_NeighborSampler_invoke(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
datapipe = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
# Invoke via class constructor. # Invoke via class constructor.
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(datapipe, graph, fanouts) datapipe = Sampler(item_sampler, graph, fanouts)
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
# Invokde via functional form. # Invokde via functional form.
if labor: if labor:
datapipe = datapipe.sample_layer_neighbor(graph, fanouts) datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
else: else:
datapipe = datapipe.sample_neighbor(graph, fanouts) datapipe = item_sampler.sample_neighbor(graph, fanouts)
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment