"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "8059fa6f8bbde5f1f26b172cc2bc91b6fade259d"
Unverified commit 2968c9b2, authored by Rhett Ying, committed by GitHub

[GraphBolt] remove SingleProcessDataLoader (#6663)

parent 018df054
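
Note for readers of this diff: the two removed classes collapse into a single ``gb.DataLoader``, which runs in the main process when ``num_workers=0`` (the old ``SingleProcessDataLoader`` behavior) and spawns worker processes otherwise (the old ``MultiProcessDataLoader`` behavior). A minimal migration sketch, assuming ``import dgl.graphbolt as gb`` and an already-built ``datapipe``:

.. code:: python

    import dgl.graphbolt as gb

    # Before this commit:
    # dataloader = gb.SingleProcessDataLoader(datapipe)
    # dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=4)

    # After this commit, one class covers both cases:
    dataloader = gb.DataLoader(datapipe)                 # main process only
    dataloader = gb.DataLoader(datapipe, num_workers=4)  # 4 worker processes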
@@ -15,6 +15,7 @@ APIs
    :nosignatures:
    :template: graphbolt_classtemplate.rst
 
+   DataLoader
    Dataset
    Task
    ItemSet
@@ -35,17 +36,6 @@ APIs
    CopyTo
 
-DataLoaders
------------
-
-.. autosummary::
-   :toctree: ../../generated/
-   :nosignatures:
-   :template: graphbolt_classtemplate.rst
-
-   SingleProcessDataLoader
-   MultiProcessDataLoader
-
 Standard Implementations
 -------------------------
......
@@ -40,7 +40,7 @@ The code below implements a classical neighbor sampler:
             seed_nodes = input_nodes
         return input_nodes, subgs
 
-To use this sampler with :class:`~dgl.graphbolt.MultiProcessDataLoader`:
+To use this sampler with :class:`~dgl.graphbolt.DataLoader`:
 
 .. code:: python
@@ -49,7 +49,7 @@ To use this sampler with :class:`~dgl.graphbolt.MultiProcessDataLoader`:
     datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
     datapipe = datapipe.to_dgl()
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+    dataloader = gb.DataLoader(datapipe, num_workers=0)
     for data in dataloader:
         input_features = data.node_features["feat"]
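
The hunks above show only the tail of the pipeline. For orientation, a self-contained sketch of the full chain with the renamed loader; ``train_set``, ``graph``, and ``feature`` are placeholder names for an already-loaded GraphBolt dataset, and the batch size and fanouts are illustrative:

.. code:: python

    import dgl.graphbolt as gb

    device = "cpu"  # or "cuda"
    # Sample minibatches of seed items from the training set.
    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    # Sample up to 10 neighbors per node for each of the two GNN layers.
    datapipe = datapipe.sample_neighbor(graph, [10, 10])
    # Fetch input-node features from the feature store.
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    # Convert to DGL minibatches and move them to the target device.
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe, num_workers=0)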
@@ -95,7 +95,7 @@ can be used on heterogeneous graphs:
     )
     datapipe = datapipe.to_dgl()
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+    dataloader = gb.DataLoader(datapipe, num_workers=0)
     for data in dataloader:
         input_features = {
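
On heterogeneous graphs the only change to feature fetching is that ``node_feature_keys`` becomes a dictionary keyed by node type rather than a flat list. A sketch, with ``"user"`` and ``"item"`` as placeholder node types:

.. code:: python

    # Fetch the "feat" feature for two node types of a heterogeneous graph.
    datapipe = datapipe.fetch_feature(
        feature, node_feature_keys={"user": ["feat"], "item": ["feat"]}
    )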
......
@@ -40,7 +40,7 @@ edges(namely, node pairs) in the training set instead of the nodes.
     datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
     datapipe = datapipe.to_dgl()
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+    dataloader = gb.DataLoader(datapipe, num_workers=0)
 
 Iterating over the DataLoader will yield :class:`~dgl.graphbolt.DGLMiniBatch`
 which contains a list of specially created graphs representing the computation
@@ -93,7 +93,7 @@ You can use :func:`~dgl.graphbolt.exclude_seed_edges` alongside with
     datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
     datapipe = datapipe.to_dgl()
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+    dataloader = gb.DataLoader(datapipe, num_workers=0)
 
 Adapt your model for minibatch training
@@ -275,7 +275,7 @@ only difference is that the train_set is now an instance of
     )
     datapipe = datapipe.to_dgl()
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+    dataloader = gb.DataLoader(datapipe, num_workers=0)
 
 Things become a little different if you wish to exclude the reverse
 edges on heterogeneous graphs. On heterogeneous graphs, reverse edges
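
The seed-edge exclusion mentioned above slots in right after neighbor sampling. A sketch, under the assumption that the pipeline's generic ``transform`` stage accepts a per-minibatch callable such as :func:`~dgl.graphbolt.exclude_seed_edges`:

.. code:: python

    from functools import partial

    datapipe = datapipe.sample_neighbor(graph, [10, 10])
    # Remove the seed (positive) edges from the sampled subgraphs so the
    # model cannot simply look up the edges it is asked to predict.
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))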
......
@@ -49,7 +49,7 @@ only one layer at a time.
     datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
     datapipe = datapipe.to_dgl()
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+    dataloader = gb.DataLoader(datapipe, num_workers=0)
 
 Note that offline inference is implemented as a method of the GNN module
......
@@ -29,7 +29,7 @@ The whole data loader pipeline is as follows:
     datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
     datapipe = datapipe.to_dgl()
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+    dataloader = gb.DataLoader(datapipe, num_workers=0)
 
 For the details about the builtin uniform negative sampler please see
@@ -215,7 +215,7 @@ only difference is that you need to give edge types for feature fetching.
     )
     datapipe = datapipe.to_dgl()
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+    dataloader = gb.DataLoader(datapipe, num_workers=0)
 
 If you want to give your own negative sampling function, just inherit from the
 :class:`~dgl.graphbolt.NegativeSampler` class and override the
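
To make the negative-sampling stage concrete: a sketch of where the builtin uniform negative sampler sits in the pipeline, assuming its functional form ``sample_uniform_negative`` and using illustrative numbers and placeholder names:

.. code:: python

    datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)
    # Pair each positive (seed) edge with 5 uniformly drawn negative edges
    # before neighbor sampling.
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [10, 10])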
......
@@ -26,7 +26,7 @@ or the equivalent function-like interface :func:`~dgl.graphbolt.sample_neighbor`
 which makes the node gather messages from its neighbors.
 
 To use a sampler provided by DGL, one also need to combine it with
-:class:`~dgl.graphbolt.MultiProcessDataLoader`, which iterates
+:class:`~dgl.graphbolt.DataLoader`, which iterates
 over a set of indices (nodes in this case) in minibatches.
 
 For example, the following code creates a DataLoader that
@@ -52,7 +52,7 @@ putting the list of generated MFGs onto GPU.
     datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
     datapipe = datapipe.to_dgl()
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+    dataloader = gb.DataLoader(datapipe, num_workers=0)
 
 Iterating over the DataLoader will yield :class:`~dgl.graphbolt.DGLMiniBatch`
@@ -196,7 +196,7 @@ removed for simplicity):
 The samplers provided by DGL also support heterogeneous graphs.
 For example, one can still use the provided
 :class:`~dgl.graphbolt.NeighborSampler` class and
-:class:`~dgl.graphbolt.MultiProcessDataLoader` class for
+:class:`~dgl.graphbolt.DataLoader` class for
 stochastic training. The only difference is that the itemset is now an
 instance of :class:`~dgl.graphbolt.ItemSetDict` which is a dictionary
 of node types to node IDs.
@@ -217,7 +217,7 @@ of node types to node IDs.
     )
     datapipe = datapipe.to_dgl()
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+    dataloader = gb.DataLoader(datapipe, num_workers=0)
 
 The training loop is almost the same as that of homogeneous graphs,
 except for the implementation of ``compute_loss`` that will take in two
......
@@ -23,7 +23,7 @@ generate a minibatch, including:
     datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
     datapipe = datapipe.to_dgl()
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+    dataloader = gb.DataLoader(datapipe, num_workers=0)
 
 All these stages are implemented in separate
 `IterableDataPipe <https://pytorch.org/data/main/torchdata.datapipes.iter.html>`__
@@ -52,5 +52,5 @@ which prefetches elements from previous data pipes and puts them into a buffer.
 Such prefetching is totally transparent to users and requires no extra code. It
 brings a significant performance boost to minibatch training of GNNs.
-Please refer to the source code of :class:`~dgl.graphbolt.MultiProcessDataLoader`
+Please refer to the source code of :class:`~dgl.graphbolt.DataLoader`
 for more details.
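
The prefetching described here is the same mechanism the removed ``SingleProcessDataLoader`` (see the dataloader.py hunk later in this diff) set up by hand: the pipeline is cut at the ``CopyTo`` stage and everything before it is wrapped in a torchdata ``Prefetcher`` running in a background thread. A conceptual sketch, where ``upstream`` is a placeholder for the pipeline up to ``CopyTo``:

.. code:: python

    import torchdata.datapipes as dp

    # Run `upstream` ahead of the consumer in a background thread,
    # keeping up to two ready elements buffered (the buffer size used
    # by the code removed in this commit).
    prefetched = dp.iter.Prefetcher(upstream, buffer_size=2)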
@@ -139,9 +139,7 @@ def create_dataloader(
     # A CopyTo object copying data in the datapipe to a specified device.
     ############################################################################
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(
-        datapipe, num_workers=args.num_workers
-    )
+    dataloader = gb.DataLoader(datapipe, num_workers=args.num_workers)
     # Return the fully-initialized DataLoader object.
     return dataloader
......
@@ -159,9 +159,7 @@ class DataModule(LightningDataModule):
         datapipe = sampler(self.graph, self.fanouts)
         datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
         datapipe = datapipe.to_dgl()
-        dataloader = gb.MultiProcessDataLoader(
-            datapipe, num_workers=self.num_workers
-        )
+        dataloader = gb.DataLoader(datapipe, num_workers=self.num_workers)
         return dataloader
 
 ########################################################################
......
@@ -232,11 +232,11 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
     # 'datapipe': The datapipe object to be used for data loading.
     # 'args.num_workers': The number of processes to be used for data loading.
     # [Output]:
-    # A MultiProcessDataLoader object to handle data loading.
+    # A DataLoader object to handle data loading.
     # [Role]:
     # Initialize a multi-process dataloader to load the data in parallel.
     ############################################################################
-    dataloader = gb.MultiProcessDataLoader(
+    dataloader = gb.DataLoader(
         datapipe,
         num_workers=args.num_workers,
     )
......
@@ -148,16 +148,16 @@ def create_dataloader(
     ############################################################################
     # [Step-6]:
-    # gb.MultiProcessDataLoader()
+    # gb.DataLoader()
     # [Input]:
     # 'datapipe': The datapipe object to be used for data loading.
     # 'num_workers': The number of processes to be used for data loading.
     # [Output]:
-    # A MultiProcessDataLoader object to handle data loading.
+    # A DataLoader object to handle data loading.
     # [Role]:
     # Initialize a multi-process dataloader to load the data in parallel.
     ############################################################################
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=num_workers)
+    dataloader = gb.DataLoader(datapipe, num_workers=num_workers)
     # Return the fully-initialized DataLoader object.
     return dataloader
......
@@ -54,7 +54,7 @@ def create_dataloader(dateset, device, is_train=True):
     datapipe = datapipe.copy_to(device)
 
     # Initiate the dataloader for the datapipe.
-    return gb.SingleProcessDataLoader(datapipe)
+    return gb.DataLoader(datapipe)
 
 class GraphSAGE(nn.Module):
......
@@ -32,7 +32,7 @@ def create_dataloader(dateset, itemset, device):
     datapipe = datapipe.copy_to(device)
 
     # Initiate the dataloader for the datapipe.
-    return gb.SingleProcessDataLoader(datapipe)
+    return gb.DataLoader(datapipe)
 
 class GCN(nn.Module):
......
@@ -137,7 +137,7 @@ def create_dataloader(
     # Create a DataLoader from the datapipe.
     # `num_workers`:
     # The number of worker processes to use for data loading.
-    return gb.MultiProcessDataLoader(datapipe, num_workers=num_workers)
+    return gb.DataLoader(datapipe, num_workers=num_workers)
 
 def extract_embed(node_embed, input_nodes):
......
@@ -136,7 +136,7 @@ def create_dataloader(A, fanouts, ids, features, device):
     # Use grapbolt to fetch features.
     datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=4)
+    dataloader = gb.DataLoader(datapipe, num_workers=4)
     return dataloader
......
@@ -145,7 +145,7 @@
     "datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
     "datapipe = datapipe.to_dgl()\n",
     "datapipe = datapipe.copy_to(device)\n",
-    "train_dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)"
+    "train_dataloader = gb.DataLoader(datapipe, num_workers=0)"
    ],
    "metadata": {
     "id": "LZgXGfBvYijJ"
@@ -344,7 +344,7 @@
     "datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
     "datapipe = datapipe.to_dgl()\n",
     "datapipe = datapipe.copy_to(device)\n",
-    "eval_dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)\n",
+    "eval_dataloader = gb.DataLoader(datapipe, num_workers=0)\n",
     "\n",
     "logits = []\n",
     "labels = []\n",
......
@@ -138,7 +138,7 @@
    "source": [
     "## Defining Neighbor Sampler and Data Loader in DGL\n",
     "\n",
-    "DGL provides tools to iterate over the dataset in minibatches while generating the computation dependencies to compute their outputs with the MFGs above. For node classification, you can use `dgl.graphbolt.MultiProcessDataLoader` for iterating over the dataset. It accepts a data pipe that generates minibatches of nodes and their labels, sample neighbors for each node, and generate the computation dependencies in the form of MFGs. Feature fetching, block creation and copying to target device are also supported. All these operations are split into separate stages in the data pipe, so that you can customize the data pipeline by inserting your own operations.\n",
+    "DGL provides tools to iterate over the dataset in minibatches while generating the computation dependencies to compute their outputs with the MFGs above. For node classification, you can use `dgl.graphbolt.DataLoader` for iterating over the dataset. It accepts a data pipe that generates minibatches of nodes and their labels, sample neighbors for each node, and generate the computation dependencies in the form of MFGs. Feature fetching, block creation and copying to target device are also supported. All these operations are split into separate stages in the data pipe, so that you can customize the data pipeline by inserting your own operations.\n",
     "\n",
     "Let’s say that each node will gather messages from 4 neighbors on each layer. The code defining the data loader and neighbor sampler will look like the following.\n"
    ],
@@ -154,7 +154,7 @@
     "datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
     "datapipe = datapipe.to_dgl()\n",
     "datapipe = datapipe.copy_to(device)\n",
-    "train_dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)"
+    "train_dataloader = gb.DataLoader(datapipe, num_workers=0)"
    ],
    "metadata": {
     "id": "yQVYDO0ZbBvi"
@@ -287,7 +287,7 @@
     "datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
     "datapipe = datapipe.to_dgl()\n",
     "datapipe = datapipe.copy_to(device)\n",
-    "valid_dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)\n",
+    "valid_dataloader = gb.DataLoader(datapipe, num_workers=0)\n",
     "\n",
     "\n",
     "import sklearn.metrics"
......
@@ -12,8 +12,7 @@ from .item_sampler import ItemSampler
 
 __all__ = [
-    "SingleProcessDataLoader",
-    "MultiProcessDataLoader",
+    "DataLoader",
 ]
@@ -36,38 +35,6 @@ def _find_and_wrap_parent(
     )
 
-class SingleProcessDataLoader(torch.utils.data.DataLoader):
-    """Single process DataLoader.
-
-    Iterates over the data pipeline in the main process.
-
-    Parameters
-    ----------
-    datapipe : DataPipe
-        The data pipeline.
-    """
-
-    # In the single process dataloader case, we don't need to do any
-    # modifications to the datapipe, and we just PyTorch's native
-    # dataloader as-is.
-    #
-    # The exception is that batch_size should be None, since we already
-    # have minibatch sampling and collating in ItemSampler.
-    def __init__(self, datapipe):
-        datapipe_graph = dp_utils.traverse_dps(datapipe)
-        datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
-        # Cut datapipe at CopyTo and wrap with prefetcher. This enables the
-        # data pipeline up to the CopyTo operation to run in a separate thread.
-        _find_and_wrap_parent(
-            datapipe_graph,
-            datapipe_adjlist,
-            CopyTo,
-            dp.iter.Prefetcher,
-            buffer_size=2,
-        )
-        super().__init__(datapipe, batch_size=None, num_workers=0)
-
 class MultiprocessingWrapper(dp.iter.IterDataPipe):
     """Wraps a datapipe with multiprocessing.
@@ -97,7 +64,7 @@ class MultiprocessingWrapper(dp.iter.IterDataPipe):
         yield from self.dataloader
 
-class MultiProcessDataLoader(torch.utils.data.DataLoader):
+class DataLoader(torch.utils.data.DataLoader):
     """Multiprocessing DataLoader.
 
     Iterates over the data pipeline with everything before feature fetching
@@ -112,8 +79,7 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
     datapipe : DataPipe
         The data pipeline.
     num_workers : int, optional
-        Number of worker processes. Default is 0, which is identical to
-        :class:`SingleProcessDataLoader`.
+        Number of worker processes. Default is 0.
     persistent_workers : bool, optional
         If True, the data loader will not shut down the worker processes after a
         dataset has been consumed once. This allows to maintain the workers
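
Reading the updated docstring together, a usage sketch of the unified class (the worker count is illustrative):

.. code:: python

    import dgl.graphbolt as gb

    # num_workers=0 keeps everything in the main process; with workers > 0
    # the stages before feature fetching run in worker processes, and
    # persistent_workers keeps those processes alive across epochs.
    dataloader = gb.DataLoader(datapipe, num_workers=4, persistent_workers=True)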
......
@@ -584,7 +584,7 @@ class DistributedItemSampler(ItemSampler):
    >>>     item_set, batch_size=2, shuffle=False, drop_last=False,
    >>>     drop_uneven_inputs=False
    >>> )
-   >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
+   >>> data_loader = gb.DataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)})
    Replica#0: [tensor([0, 1]), tensor([2, 3])]
    Replica#1: [tensor([4, 5]), tensor([6, 7])]
@@ -597,7 +597,7 @@ class DistributedItemSampler(ItemSampler):
    >>>     item_set, batch_size=2, shuffle=False, drop_last=True,
    >>>     drop_uneven_inputs=False
    >>> )
-   >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
+   >>> data_loader = gb.DataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)})
    Replica#0: [tensor([0, 1]), tensor([2, 3])]
    Replica#1: [tensor([4, 5]), tensor([6, 7])]
@@ -610,7 +610,7 @@ class DistributedItemSampler(ItemSampler):
    >>>     item_set, batch_size=2, shuffle=False, drop_last=False,
    >>>     drop_uneven_inputs=True
    >>> )
-   >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
+   >>> data_loader = gb.DataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)})
    Replica#0: [tensor([0, 1]), tensor([2, 3])]
    Replica#1: [tensor([4, 5]), tensor([6, 7])]
@@ -623,7 +623,7 @@ class DistributedItemSampler(ItemSampler):
    >>>     item_set, batch_size=2, shuffle=False, drop_last=True,
    >>>     drop_uneven_inputs=True
    >>> )
-   >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
+   >>> data_loader = gb.DataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)})
    Replica#0: [tensor([0, 1])]
    Replica#1: [tensor([4, 5])]
@@ -636,7 +636,7 @@ class DistributedItemSampler(ItemSampler):
    >>>     item_set, batch_size=2, shuffle=True, drop_last=True,
    >>>     drop_uneven_inputs=False
    >>> )
-   >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
+   >>> data_loader = gb.DataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)})
    (One possible output:)
    Replica#0: [tensor([3, 2]), tensor([0, 1])]
@@ -650,7 +650,7 @@ class DistributedItemSampler(ItemSampler):
    >>>     item_set, batch_size=2, shuffle=True, drop_last=True,
    >>>     drop_uneven_inputs=True
    >>> )
-   >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
+   >>> data_loader = gb.DataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)})
    (One possible output:)
    Replica#0: [tensor([1, 3])]
......
@@ -2085,7 +2085,7 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id):
         dataset.feature, node_feature_keys=["feat"]
     )
     datapipe = datapipe.to_dgl()
-    dataloader = gb.MultiProcessDataLoader(datapipe)
+    dataloader = gb.DataLoader(datapipe)
     for _ in dataloader:
         pass
@@ -2157,7 +2157,7 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id):
         dataset.feature, node_feature_keys={"user": ["feat"]}
     )
     datapipe = datapipe.to_dgl()
-    dataloader = gb.MultiProcessDataLoader(datapipe)
+    dataloader = gb.DataLoader(datapipe)
    for _ in dataloader:
        pass
......