Unverified Commit 3b37918b authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Add more detail to CopyTo docstring (#6904)

parent b9cf36c3
...@@ -104,17 +104,41 @@ class CopyTo(IterDataPipe): ...@@ -104,17 +104,41 @@ class CopyTo(IterDataPipe):
"""DataPipe that transfers each element yielded from the previous DataPipe """DataPipe that transfers each element yielded from the previous DataPipe
to the given device. For MiniBatch, only the related attributes to the given device. For MiniBatch, only the related attributes
(automatically inferred) will be transferred by default. If you want to (automatically inferred) will be transferred by default. If you want to
transfer any other attributes, indicate them in the `extra_attrs`. transfer any other attributes, indicate them in the ``extra_attrs``.
Functional name: :obj:`copy_to`. Functional name: :obj:`copy_to`.
This is equivalent to When ``data`` has ``to`` method implemented, ``CopyTo`` will be equivalent
to
.. code:: python .. code:: python
for data in datapipe: for data in datapipe:
yield data.to(device) yield data.to(device)
For :class:`~dgl.graphbolt.MiniBatch`, only a part of attributes will be
transferred to accelerate the process by default:
- When ``seed_nodes`` is not None and ``node_pairs`` is None, node related
task is inferred. Only ``labels``, ``sampled_subgraphs``, ``node_features``
and ``edge_features`` will be transferred.
- When ``node_pairs`` is not None and ``seed_nodes`` is None, edge/link
related task is inferred. Only ``labels``, ``compacted_node_pairs``,
``compacted_negative_srcs``, ``compacted_negative_dsts``,
``sampled_subgraphs``, ``node_features`` and ``edge_features`` will be
transferred.
- Otherwise, all attributes will be transferred.
- If you want some other attributes to be transferred as well, please
specify the name in the ``extra_attrs``. For instance, the following code
will copy ``seed_nodes`` to the GPU as well:
.. code:: python
datapipe = datapipe.copy_to(device="cuda", extra_attrs=["seed_nodes"])
Parameters Parameters
---------- ----------
datapipe : DataPipe datapipe : DataPipe
...@@ -122,8 +146,10 @@ class CopyTo(IterDataPipe): ...@@ -122,8 +146,10 @@ class CopyTo(IterDataPipe):
device : torch.device device : torch.device
The PyTorch CUDA device. The PyTorch CUDA device.
extra_attrs: List[string] extra_attrs: List[string]
The extra attributes in the MiniBatch you want to be carried to the The extra attributes of the data in the DataPipe you want to be carried
specific device. to the specific device. The attributes specified in the ``extra_attrs``
will be transferred regardless of the task inferred. It could also be
applied to classes other than :class:`~dgl.graphbolt.MiniBatch`.
""" """
def __init__(self, datapipe, device, extra_attrs=None): def __init__(self, datapipe, device, extra_attrs=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment