"""Base types and utilities for Graph Bolt."""

from collections import deque
from dataclasses import dataclass

import torch
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

from ..utils import recursive_apply

__all__ = [
    "CANONICAL_ETYPE_DELIMITER",
    "ORIGINAL_EDGE_ID",
    "etype_str_to_tuple",
    "etype_tuple_to_str",
    "CopyTo",
    "FutureWaiter",
    "Waiter",
    "Bufferer",
    "EndMarker",
    "isin",
    "index_select",
    "expand_indptr",
    "CSCFormatBase",
    "seed",
]

CANONICAL_ETYPE_DELIMITER = ":"
ORIGINAL_EDGE_ID = "_ORIGINAL_EDGE_ID"


def seed(val):
    """Set the random seed of Graphbolt.

    Parameters
    ----------
    val : int
        The seed.
    """
    torch.ops.graphbolt.set_seed(val)


def isin(elements, test_elements):
    """Tests if each element of elements is in test_elements. Returns a boolean
    tensor of the same shape as elements that is True for elements in
    test_elements and False otherwise.

    Parameters
    ----------
    elements : torch.Tensor
        A 1D tensor representing the input elements.
    test_elements : torch.Tensor
        A 1D tensor representing the values to test against for each input.

    Examples
    --------
    >>> isin(torch.tensor([1, 2, 3, 4]), torch.tensor([2, 3]))
    tensor([False,  True,  True, False])
    """
    assert elements.dim() == 1, "Elements should be 1D tensor."
    assert test_elements.dim() == 1, "Test_elements should be 1D tensor."
    return torch.ops.graphbolt.isin(elements, test_elements)


def expand_indptr(indptr, dtype=None, node_ids=None, output_size=None):
    """Converts a given indptr offset tensor to a COO format tensor. If
    node_ids is not given, it is assumed to be equal to
    torch.arange(indptr.size(0) - 1, dtype=dtype, device=indptr.device).

    This is equivalent to

    .. code:: python

       if node_ids is None:
           node_ids = torch.arange(len(indptr) - 1, dtype=dtype, device=indptr.device)
       return node_ids.to(dtype).repeat_interleave(indptr.diff())

    Parameters
    ----------
    indptr : torch.Tensor
        A 1D tensor representing the csc_indptr tensor.
    dtype : Optional[torch.dtype]
        The dtype of the returned output tensor.
    node_ids : Optional[torch.Tensor]
        A 1D tensor representing the column node ids that the returned tensor will
        be populated with.
    output_size : Optional[int]
        The size of the output tensor. Should be equal to indptr[-1]. Using this
        argument avoids a stream synchronization to calculate the output shape.

    Returns
    -------
    torch.Tensor
        The converted COO tensor with values from node_ids.
    """
    assert indptr.dim() == 1, "Indptr should be 1D tensor."
    assert not (
        node_ids is None and dtype is None
    ), "One of node_ids or dtype must be given."
    assert (
        node_ids is None or node_ids.dim() == 1
    ), "Node_ids should be 1D tensor."
    if dtype is None:
        dtype = node_ids.dtype
    return torch.ops.graphbolt.expand_indptr(
        indptr, dtype, node_ids, output_size
    )
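
# A minimal usage sketch (not part of the module API), with made-up tensors:
# ``expand_indptr`` repeats each column id by its degree, turning a CSC indptr
# into COO column ids.
#
#   indptr = torch.tensor([0, 2, 3, 5])
#   cols = expand_indptr(indptr, dtype=torch.int64, output_size=5)
#   # cols == torch.tensor([0, 0, 1, 2, 2])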


def index_select(tensor, index):
    """Returns a new tensor which indexes the input tensor along dimension dim
    using the entries in index.

    The returned tensor has the same number of dimensions as the original tensor
    (tensor). The first dimension has the same size as the length of index;
    other dimensions have the same size as in the original tensor.

    When tensor is a pinned tensor and index.is_cuda is True, the operation runs
    on the CUDA device and the returned tensor will also be on CUDA.

    Parameters
    ----------
    tensor : torch.Tensor
        The input tensor.
    index : torch.Tensor
        The 1-D tensor containing the indices to index.

    Returns
    -------
    torch.Tensor
        The indexed input tensor, equivalent to tensor[index].
    """
    assert index.dim() == 1, "Index should be 1D tensor."
    return torch.ops.graphbolt.index_select(tensor, index)
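
# A hedged sketch of the pinned-memory path described above (hypothetical
# tensors; requires a CUDA device): when ``tensor`` is pinned and ``index`` is
# on the GPU, the gather runs on the GPU and the result is a CUDA tensor.
#
#   features = torch.randn(100, 16).pin_memory()
#   rows = torch.tensor([0, 5, 7], device="cuda")
#   gathered = index_select(features, rows)  # gathered.is_cuda is True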


def etype_tuple_to_str(c_etype):
    """Convert canonical etype from tuple to string.

    Examples
    --------
    >>> c_etype = ("user", "like", "item")
    >>> c_etype_str = etype_tuple_to_str(c_etype)
    >>> print(c_etype_str)
    "user:like:item"
    """
    assert isinstance(c_etype, tuple) and len(c_etype) == 3, (
        "Passed-in canonical etype should be in format of (str, str, str). "
        f"But got {c_etype}."
    )
    return CANONICAL_ETYPE_DELIMITER.join(c_etype)


def etype_str_to_tuple(c_etype):
    """Convert canonical etype from string to tuple.

    Examples
    --------
    >>> c_etype_str = "user:like:item"
    >>> c_etype = etype_str_to_tuple(c_etype_str)
    >>> print(c_etype)
    ("user", "like", "item")
    """
    if isinstance(c_etype, tuple):
        return c_etype
    ret = tuple(c_etype.split(CANONICAL_ETYPE_DELIMITER))
    assert len(ret) == 3, (
        "Passed-in canonical etype should be in format of 'str:str:str'. "
        f"But got {c_etype}."
    )
    return ret
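
# Illustrative round-trip sketch (made-up edge type): the two helpers above
# are inverses of each other for canonical edge types.
#
#   c_etype = ("user", "like", "item")
#   assert etype_str_to_tuple(etype_tuple_to_str(c_etype)) == c_etype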


def apply_to(x, device):
    """Apply `to` function to object x only if it has `to`."""

    return x.to(device) if hasattr(x, "to") else x


@functional_datapipe("copy_to")
class CopyTo(IterDataPipe):
    """DataPipe that transfers each element yielded from the previous DataPipe
184
185
    to the given device. For MiniBatch, only the related attributes
    (automatically inferred) will be transferred by default. If you want to
186
    transfer any other attributes, indicate them in the ``extra_attrs``.
187

188
189
    Functional name: :obj:`copy_to`.

190
191
    When ``data`` has ``to`` method implemented, ``CopyTo`` will be equivalent
    to
192
193
194
195
196
197

    .. code:: python

       for data in datapipe:
           yield data.to(device)

    For :class:`~dgl.graphbolt.MiniBatch`, only a subset of attributes will be
    transferred by default to accelerate the process:

    - When ``seed_nodes`` is not None and ``node_pairs`` is None, node related
    task is inferred. Only ``labels``, ``sampled_subgraphs``, ``node_features``
    and ``edge_features`` will be transferred.

    - When ``node_pairs`` is not None and ``seed_nodes`` is None, edge/link
    related task is inferred. Only ``labels``, ``compacted_node_pairs``,
    ``compacted_negative_srcs``, ``compacted_negative_dsts``,
    ``sampled_subgraphs``, ``node_features`` and ``edge_features`` will be
    transferred.

    - Otherwise, all attributes will be transferred.

    - If you want some other attributes to be transferred as well, please
    specify their names in ``extra_attrs``. For instance, the following code
    will copy ``seed_nodes`` to the GPU as well:

    .. code:: python

       datapipe = datapipe.copy_to(device="cuda", extra_attrs=["seed_nodes"])

    Parameters
    ----------
    datapipe : DataPipe
        The DataPipe.
    device : torch.device
        The PyTorch CUDA device.
    extra_attrs : List[str]
        The extra attributes of the data in the DataPipe you want to be carried
        to the specified device. The attributes specified in ``extra_attrs``
        will be transferred regardless of the inferred task. It can also be
        applied to classes other than :class:`~dgl.graphbolt.MiniBatch`.
    """

    def __init__(self, datapipe, device, extra_attrs=None):
        super().__init__()
        self.datapipe = datapipe
        self.device = device
        self.extra_attrs = extra_attrs

    def __iter__(self):
        for data in self.datapipe:
            data = recursive_apply(data, apply_to, self.device)
            if self.extra_attrs is not None:
                for attr in self.extra_attrs:
                    setattr(
                        data,
                        attr,
                        recursive_apply(
                            getattr(data, attr), apply_to, self.device
                        ),
                    )
            yield data
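
# A hedged usage sketch, assuming an existing GraphBolt datapipe that yields
# MiniBatch instances (the ``datapipe`` name below is a placeholder):
#
#   datapipe = datapipe.copy_to(device="cuda")
#   # To also move ``seed_nodes`` regardless of the inferred task:
#   datapipe = datapipe.copy_to(device="cuda", extra_attrs=["seed_nodes"])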


@functional_datapipe("mark_end")
class EndMarker(IterDataPipe):
    """Used to mark the end of a datapipe and is a no-op."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        yield from self.datapipe


@functional_datapipe("buffer")
class Bufferer(IterDataPipe):
    """Buffers items before yielding them.

    Parameters
    ----------
    datapipe : DataPipe
        The data pipeline.
    buffer_size : int, optional
        The size of the buffer which stores the fetched samples. If data coming
        from datapipe has latency spikes, consider setting to a higher value.
        Default is 1.
    """

    def __init__(self, datapipe, buffer_size=1):
        self.datapipe = datapipe
        if buffer_size <= 0:
            raise ValueError(
                "'buffer_size' is required to be a positive integer."
            )
        self.buffer = deque(maxlen=buffer_size)

    def __iter__(self):
        for data in self.datapipe:
            if len(self.buffer) < self.buffer.maxlen:
                self.buffer.append(data)
            else:
                return_data = self.buffer.popleft()
                self.buffer.append(data)
                yield return_data
        while len(self.buffer) > 0:
            yield self.buffer.popleft()
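
# Usage sketch (hypothetical pipeline): a small buffer smooths over latency
# spikes from an upstream stage such as a sampler.
#
#   datapipe = datapipe.buffer(buffer_size=4)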


@functional_datapipe("wait")
class Waiter(IterDataPipe):
    """Calls the wait function of all items."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for data in self.datapipe:
            data.wait()
            yield data


@functional_datapipe("wait_future")
class FutureWaiter(IterDataPipe):
    """Calls the result function of all items and returns their results."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for data in self.datapipe:
            yield data.result()
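
# A sketch of how these stages are typically chained (assumed pipeline; the
# upstream stage is presumed to yield objects exposing ``wait()``, or futures
# exposing ``result()`` for the ``wait_future`` variant):
#
#   datapipe = datapipe.buffer(buffer_size=2).wait()         # items with .wait()
#   datapipe = datapipe.buffer(buffer_size=2).wait_future()  # futures with .result()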


@dataclass
class CSCFormatBase:
    r"""Basic class representing data in Compressed Sparse Column (CSC) format.

    Examples
    --------
    >>> indptr = torch.tensor([0, 1, 3])
    >>> indices = torch.tensor([1, 4, 2])
    >>> csc_format_base = CSCFormatBase(indptr=indptr, indices=indices)
    >>> print(csc_format_base.indptr)
    ... torch.tensor([0, 1, 3])
    >>> print(csc_format_base.indices)
    ... torch.tensor([1, 4, 2])
    """
    indptr: torch.Tensor = None
    indices: torch.Tensor = None

    def __init__(self, indptr: torch.Tensor, indices: torch.Tensor):
        self.indptr = indptr
        self.indices = indices
        if not indptr.is_cuda:
            assert self.indptr[-1] == len(
                self.indices
            ), "The last element of indptr should be the same as the length of indices."

    def __repr__(self) -> str:
        return _csc_format_base_str(self)

    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
        """Copy `CSCFormatBase` to the specified device using reflection."""

        for attr in dir(self):
            # Only copy member variables.
            if not callable(getattr(self, attr)) and not attr.startswith("__"):
                setattr(
                    self,
                    attr,
                    recursive_apply(
                        getattr(self, attr), lambda x: apply_to(x, device)
                    ),
                )

        return self
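
# Hedged sketch of the reflective ``to`` method above (made-up tensors;
# requires a CUDA device):
#
#   csc = CSCFormatBase(indptr=torch.tensor([0, 1, 3]),
#                       indices=torch.tensor([1, 4, 2]))
#   csc = csc.to("cuda")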


def _csc_format_base_str(csc_format_base: CSCFormatBase) -> str:
    final_str = "CSCFormatBase("

    def _add_indent(_str, indent):
        lines = _str.split("\n")
        lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
        return "\n".join(lines)

    final_str += (
        f"indptr={_add_indent(str(csc_format_base.indptr), 21)},\n" + " " * 14
    )
    final_str += (
        f"indices={_add_indent(str(csc_format_base.indices), 22)},\n" + ")"
    )
    return final_str