Unverified Commit 6aba92e9 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] fix test cases about datapipe (#6305)

parent 1328baf7
......@@ -10,15 +10,15 @@ import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
dp = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)
item_sampler = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)
# Invoke CopyTo via class constructor.
dp = gb.CopyTo(dp, "cuda")
dp = gb.CopyTo(item_sampler, "cuda")
for data in dp:
assert data.device.type == "cuda"
# Invoke CopyTo via functional form.
dp = dp.copy_to("cuda")
dp = item_sampler.copy_to("cuda")
for data in dp:
assert data.device.type == "cuda"
......
......@@ -17,17 +17,17 @@ def test_FeatureFetcher_invoke():
feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
datapipe = gb.ItemSampler(itemset, batch_size=2)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
# Invoke FeatureFetcher via class constructor.
datapipe = gb.NeighborSampler(datapipe, graph, fanouts)
datapipe = gb.NeighborSampler(item_sampler, graph, fanouts)
datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"])
assert len(list(datapipe)) == 5
# Invoke FeatureFetcher via functional form.
datapipe = datapipe.sample_neighbor(graph, fanouts).fetch_feature(
datapipe = item_sampler.sample_neighbor(graph, fanouts).fetch_feature(
feature_store, ["a"], ["b"]
)
assert len(list(datapipe)) == 5
......
......@@ -7,15 +7,15 @@ from torchdata.datapipes.iter import Mapper
def test_SubgraphSampler_invoke():
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
datapipe = gb.ItemSampler(itemset, batch_size=2)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
# Invoke via class constructor.
datapipe = gb.SubgraphSampler(datapipe)
datapipe = gb.SubgraphSampler(item_sampler)
with pytest.raises(NotImplementedError):
next(iter(datapipe))
# Invokde via functional form.
datapipe = datapipe.sample_subgraph()
datapipe = item_sampler.sample_subgraph()
with pytest.raises(NotImplementedError):
next(iter(datapipe))
......@@ -24,20 +24,20 @@ def test_SubgraphSampler_invoke():
def test_NeighborSampler_invoke(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
datapipe = gb.ItemSampler(itemset, batch_size=2)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
# Invoke via class constructor.
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(datapipe, graph, fanouts)
datapipe = Sampler(item_sampler, graph, fanouts)
assert len(list(datapipe)) == 5
# Invokde via functional form.
if labor:
datapipe = datapipe.sample_layer_neighbor(graph, fanouts)
datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
else:
datapipe = datapipe.sample_neighbor(graph, fanouts)
datapipe = item_sampler.sample_neighbor(graph, fanouts)
assert len(list(datapipe)) == 5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment