Unverified Commit 2968c9b2 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] remove SingleProcessDataLoader (#6663)

parent 018df054
...@@ -56,7 +56,7 @@ def test_integration_link_prediction(): ...@@ -56,7 +56,7 @@ def test_integration_link_prediction():
feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"] feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
) )
datapipe = datapipe.to_dgl() datapipe = datapipe.to_dgl()
dataloader = gb.SingleProcessDataLoader( dataloader = gb.DataLoader(
datapipe, datapipe,
) )
expected = [ expected = [
...@@ -71,13 +71,13 @@ def test_integration_link_prediction(): ...@@ -71,13 +71,13 @@ def test_integration_link_prediction():
[0.9634, 0.2294], [0.9634, 0.2294],
[0.5503, 0.8223]])}, [0.5503, 0.8223]])},
negative_node_pairs=(tensor([0, 1, 1, 1]), negative_node_pairs=(tensor([0, 1, 1, 1]),
tensor([0, 3, 4, 5])), tensor([4, 4, 1, 4])),
labels=None, labels=None,
input_nodes=None, input_nodes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2), blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)], Block(num_src_nodes=6, num_dst_nodes=5, num_edges=1)],
)""" )"""
), ),
str( str(
...@@ -90,7 +90,7 @@ def test_integration_link_prediction(): ...@@ -90,7 +90,7 @@ def test_integration_link_prediction():
[0.5160, 0.2486], [0.5160, 0.2486],
[0.6172, 0.7865]])}, [0.6172, 0.7865]])},
negative_node_pairs=(tensor([0, 1, 1, 2]), negative_node_pairs=(tensor([0, 1, 1, 2]),
tensor([1, 3, 4, 1])), tensor([1, 1, 3, 4])),
labels=None, labels=None,
input_nodes=None, input_nodes=None,
edge_features=[{}, edge_features=[{},
...@@ -104,17 +104,15 @@ def test_integration_link_prediction(): ...@@ -104,17 +104,15 @@ def test_integration_link_prediction():
tensor([0, 0])), tensor([0, 0])),
output_nodes=None, output_nodes=None,
node_features={'feat': tensor([[0.5160, 0.2486], node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223], [0.5503, 0.8223]])},
[0.8672, 0.2276],
[0.9634, 0.2294]])},
negative_node_pairs=(tensor([0, 1]), negative_node_pairs=(tensor([0, 1]),
tensor([1, 2])), tensor([0, 0])),
labels=None, labels=None,
input_nodes=None, input_nodes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2), blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1),
Block(num_src_nodes=4, num_dst_nodes=3, num_edges=2)], Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)],
)""" )"""
), ),
] ]
...@@ -172,7 +170,7 @@ def test_integration_node_classification(): ...@@ -172,7 +170,7 @@ def test_integration_node_classification():
feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"] feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
) )
datapipe = datapipe.to_dgl() datapipe = datapipe.to_dgl()
dataloader = gb.SingleProcessDataLoader( dataloader = gb.DataLoader(
datapipe, datapipe,
) )
expected = [ expected = [
...@@ -184,15 +182,14 @@ def test_integration_node_classification(): ...@@ -184,15 +182,14 @@ def test_integration_node_classification():
[0.8672, 0.2276], [0.8672, 0.2276],
[0.6172, 0.7865], [0.6172, 0.7865],
[0.2109, 0.1089], [0.2109, 0.1089],
[0.5503, 0.8223], [0.5503, 0.8223]])},
[0.9634, 0.2294]])},
negative_node_pairs=None, negative_node_pairs=None,
labels=None, labels=None,
input_nodes=None, input_nodes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
blocks=[Block(num_src_nodes=6, num_dst_nodes=5, num_edges=5), blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4)], Block(num_src_nodes=4, num_dst_nodes=4, num_edges=4)],
)""" )"""
), ),
str( str(
......
...@@ -759,9 +759,7 @@ def distributed_item_sampler_subprocess( ...@@ -759,9 +759,7 @@ def distributed_item_sampler_subprocess(
gb.BasicFeatureStore({}), gb.BasicFeatureStore({}),
[], [],
) )
data_loader = gb.MultiProcessDataLoader( data_loader = gb.DataLoader(feature_fetcher, num_workers=num_workers)
feature_fetcher, num_workers=num_workers
)
# Count the numbers of items and batches. # Count the numbers of items and batches.
num_items = 0 num_items = 0
......
...@@ -27,7 +27,7 @@ def test_dgl_minibatch_converter(): ...@@ -27,7 +27,7 @@ def test_dgl_minibatch_converter():
["a"], ["a"],
) )
dgl_converter = gb.DGLMiniBatchConverter(feature_fetcher) dgl_converter = gb.DGLMiniBatchConverter(feature_fetcher)
dataloader = gb.SingleProcessDataLoader(dgl_converter) dataloader = gb.DataLoader(dgl_converter)
assert len(list(dataloader)) == N // B assert len(list(dataloader)) == N // B
minibatch = next(iter(dataloader)) minibatch = next(iter(dataloader))
assert isinstance(minibatch, gb.DGLMiniBatch) assert isinstance(minibatch, gb.DGLMiniBatch)
...@@ -34,7 +34,7 @@ def test_DataLoader(): ...@@ -34,7 +34,7 @@ def test_DataLoader():
) )
device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx()) device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
dataloader = dgl.graphbolt.MultiProcessDataLoader( dataloader = dgl.graphbolt.DataLoader(
device_transferrer, device_transferrer,
num_workers=4, num_workers=4,
) )
......
...@@ -32,5 +32,5 @@ def test_DataLoader(): ...@@ -32,5 +32,5 @@ def test_DataLoader():
) )
device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx()) device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
dataloader = dgl.graphbolt.SingleProcessDataLoader(device_transferrer) dataloader = dgl.graphbolt.DataLoader(device_transferrer)
assert len(list(dataloader)) == N // B assert len(list(dataloader)) == N // B
...@@ -120,7 +120,7 @@ def create_dataloader( ...@@ -120,7 +120,7 @@ def create_dataloader(
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.to_dgl() datapipe = datapipe.to_dgl()
datapipe = datapipe.copy_to(device) datapipe = datapipe.copy_to(device)
dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0) dataloader = gb.DataLoader(datapipe, num_workers=0)
return dataloader return dataloader
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment