fused_csc_sampling_graph.py 45.4 KB
Newer Older
1
2
"""CSC format sampling graph."""
# pylint: disable= invalid-name
3
from typing import Dict, Optional, Union
4
5
6

import torch

7
8
from dgl.utils import recursive_apply

9
from ...base import EID, ETYPE, NID, NTYPE
10
11
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
12
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
13
from ..sampling_graph import SamplingGraph
14
from .sampled_subgraph_impl import CSCFormatBase, SampledSubgraphImpl
15

16

17
__all__ = [
18
    "FusedCSCSamplingGraph",
19
    "fused_csc_sampling_graph",
20
21
22
23
24
    "load_from_shared_memory",
    "from_dglgraph",
]


25
class FusedCSCSamplingGraph(SamplingGraph):
26
    r"""A sampling graph in CSC format."""
27
28
29
30
31

    def __repr__(self):
        """Return the printable form of this graph.

        The text is built by the module-level helper
        `_csc_sampling_graph_str`.
        """
        text = _csc_sampling_graph_str(self)
        return text

    def __init__(
        self,
        c_csc_graph: torch.ScriptObject,
    ):
        """Wrap a C-side CSC graph object.

        Parameters
        ----------
        c_csc_graph: torch.ScriptObject
            The underlying graph object. It is stored on the instance and the
            properties and sampling methods of this class forward to it.
        """
        super().__init__()
        # Single piece of instance state: the handle everything delegates to.
        self._c_csc_graph = c_csc_graph

    @property
39
    def total_num_nodes(self) -> int:
40
41
42
43
44
45
46
47
48
49
        """Returns the number of nodes in the graph.

        Returns
        -------
        int
            The number of rows in the dense format.
        """
        return self._c_csc_graph.num_nodes()

    @property
50
    def total_num_edges(self) -> int:
51
52
53
54
55
56
57
58
59
        """Returns the number of edges in the graph.

        Returns
        -------
        int
            The number of edges in the graph.
        """
        return self._c_csc_graph.num_edges()

60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    @property
    def num_nodes(self) -> Union[int, Dict[str, int]]:
        """The number of nodes in the graph.
        - If the graph is homogenous, returns an integer.
        - If the graph is heterogenous, returns a dictionary.

        Returns
        -------
        Union[int, Dict[str, int]]
            The number of nodes. Integer indicates the total nodes number of a
            homogenous graph; dict indicates nodes number per node types of a
            heterogenous graph.

        Examples
        --------
        >>> import dgl.graphbolt as gb, torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
80
        ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
81
82
83
84
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
85
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
86
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
87
88
89
90
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
91
        >>> print(graph.num_nodes)
92
        {'N0': 2, 'N1': 3}
93
94
95
96
97
        """

        offset = self.node_type_offset

        # Homogenous.
98
        if offset is None or self.node_type_to_id is None:
99
100
101
102
103
            return self._c_csc_graph.num_nodes()

        # Heterogenous
        else:
            num_nodes_per_type = {
104
                _type: (offset[_idx + 1] - offset[_idx]).item()
105
                for _type, _idx in self.node_type_to_id.items()
106
107
108
109
            }

            return num_nodes_per_type

110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
    @property
    def num_edges(self) -> Union[int, Dict[str, int]]:
        """The number of edges in the graph.
        - If the graph is homogenous, returns an integer.
        - If the graph is heterogenous, returns a dictionary.

        Returns
        -------
        Union[int, Dict[str, int]]
            The number of edges. Integer indicates the total edges number of a
            homogenous graph; dict indicates edges number per edge types of a
            heterogenous graph.

        Examples
        --------
        >>> import dgl.graphbolt as gb, torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
        ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> metadata = gb.GraphMetadata(ntypes, etypes)
137
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices, node_type_offset,
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
        ...     type_per_edge, None, metadata)
        >>> print(graph.num_edges)
        {'N0:R0:N0': 2, 'N0:R1:N1': 1, 'N1:R2:N0': 2, 'N1:R3:N1': 3}
        """

        type_per_edge = self.type_per_edge

        # Homogenous.
        if type_per_edge is None or self.edge_type_to_id is None:
            return self._c_csc_graph.num_edges()

        # Heterogenous
        bincount = torch.bincount(type_per_edge)
        num_edges_per_type = {}
        for etype, etype_id in self.edge_type_to_id.items():
            if etype_id < len(bincount):
                num_edges_per_type[etype] = bincount[etype_id].item()
            else:
                num_edges_per_type[etype] = 0
        return num_edges_per_type

159
160
161
162
163
164
165
166
    @property
    def csc_indptr(self) -> torch.tensor:
        """Returns the indices pointer in the CSC graph.

        Returns
        -------
        torch.tensor
            The indices pointer in the CSC graph. An integer tensor with
167
            shape `(total_num_nodes+1,)`.
168
169
170
        """
        return self._c_csc_graph.csc_indptr()

171
172
173
174
175
    @csc_indptr.setter
    def csc_indptr(self, csc_indptr: torch.tensor) -> None:
        """Sets the indices pointer in the CSC graph."""
        self._c_csc_graph.set_csc_indptr(csc_indptr)

176
177
178
179
180
181
182
183
    @property
    def indices(self) -> torch.tensor:
        """Returns the indices in the CSC graph.

        Returns
        -------
        torch.tensor
            The indices in the CSC graph. An integer tensor with shape
184
            `(total_num_edges,)`.
185
186
187
188
189
190
191
192

        Notes
        -------
        It is assumed that edges of each node are already sorted by edge type
        ids.
        """
        return self._c_csc_graph.indices()

193
194
195
196
197
    @indices.setter
    def indices(self, indices: torch.tensor) -> None:
        """Sets the indices in the CSC graph."""
        self._c_csc_graph.set_indices(indices)

198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
    @property
    def node_type_offset(self) -> Optional[torch.Tensor]:
        """Returns the node type offset tensor if present.

        Returns
        -------
        torch.Tensor or None
            If present, returns a 1D integer tensor of shape
            `(num_node_types + 1,)`. The tensor is in ascending order as nodes
            of the same type have continuous IDs, and larger node IDs are
            paired with larger node type IDs. The first value is 0 and last
            value is the number of nodes. And nodes with IDs between
            `node_type_offset_[i]~node_type_offset_[i+1]` are of type id 'i'.

        """
        return self._c_csc_graph.node_type_offset()

215
216
217
218
219
220
221
    @node_type_offset.setter
    def node_type_offset(
        self, node_type_offset: Optional[torch.Tensor]
    ) -> None:
        """Sets the node type offset tensor if present."""
        self._c_csc_graph.set_node_type_offset(node_type_offset)

222
223
224
225
226
227
228
    @property
    def type_per_edge(self) -> Optional[torch.Tensor]:
        """Returns the edge type tensor if present.

        Returns
        -------
        torch.Tensor or None
229
            If present, returns a 1D integer tensor of shape (total_num_edges,)
230
231
232
233
            containing the type of each edge in the graph.
        """
        return self._c_csc_graph.type_per_edge()

234
235
236
237
238
    @type_per_edge.setter
    def type_per_edge(self, type_per_edge: Optional[torch.Tensor]) -> None:
        """Sets the edge type tensor if present."""
        self._c_csc_graph.set_type_per_edge(type_per_edge)

239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
    @property
    def node_type_to_id(self) -> Optional[Dict[str, int]]:
        """Returns the node type to id dictionary if present.

        Returns
        -------
        Dict[str, int] or None
            If present, returns a dictionary mapping node type to node type
            id.
        """
        return self._c_csc_graph.node_type_to_id()

    @node_type_to_id.setter
    def node_type_to_id(
        self, node_type_to_id: Optional[Dict[str, int]]
    ) -> None:
        """Sets the node type to id dictionary if present."""
        self._c_csc_graph.set_node_type_to_id(node_type_to_id)

    @property
    def edge_type_to_id(self) -> Optional[Dict[str, int]]:
        """Returns the edge type to id dictionary if present.

        Returns
        -------
        Dict[str, int] or None
            If present, returns a dictionary mapping edge type to edge type
            id.
        """
        return self._c_csc_graph.edge_type_to_id()

    @edge_type_to_id.setter
    def edge_type_to_id(
        self, edge_type_to_id: Optional[Dict[str, int]]
    ) -> None:
        """Sets the edge type to id dictionary if present."""
        self._c_csc_graph.set_edge_type_to_id(edge_type_to_id)

277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
    @property
    def node_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
        """Returns the node attributes dictionary.

        Returns
        -------
        Dict[str, torch.Tensor] or None
            If present, returns a dictionary of node attributes. Each key
            represents the attribute's name, while the corresponding value
            holds the attribute's specific value. The length of each value
            should match the total number of nodes."
        """
        return self._c_csc_graph.node_attributes()

    @node_attributes.setter
    def node_attributes(
        self, node_attributes: Optional[Dict[str, torch.Tensor]]
    ) -> None:
        """Sets the node attributes dictionary."""
        self._c_csc_graph.set_node_attributes(node_attributes)

298
299
300
301
302
303
    @property
    def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
        """Returns the edge attributes dictionary.

        Returns
        -------
304
        Dict[str, torch.Tensor] or None
305
306
307
308
309
310
311
            If present, returns a dictionary of edge attributes. Each key
            represents the attribute's name, while the corresponding value
            holds the attribute's specific value. The length of each value
            should match the total number of edges."
        """
        return self._c_csc_graph.edge_attributes()

312
313
314
315
316
317
318
    @edge_attributes.setter
    def edge_attributes(
        self, edge_attributes: Optional[Dict[str, torch.Tensor]]
    ) -> None:
        """Sets the edge attributes dictionary."""
        self._c_csc_graph.set_edge_attributes(edge_attributes)

319
    def in_subgraph(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
    ) -> SampledSubgraphImpl:
        """Build the subgraph induced on the inbound edges of `nodes`.

        The in subgraph is what one gets by creating a new graph from the
        incoming edges of the given nodes. It is compacted according to the
        order of the passed-in `nodes`.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: It means the graph is homogeneous
                graph, and ids inside are homogeneous ids.
              - If `nodes` is a dictionary: The keys should be node type and
                ids inside are heterogeneous ids.
            The (converted) IDs must form a 1-D tensor without duplicates.

        Returns
        -------
        SampledSubgraphImpl
            The in subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {
        ...     "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
        >>> in_subgraph = graph.in_subgraph(nodes)
        >>> print(in_subgraph.sampled_csc)
        {'N0:R0:N0': CSCFormatBase(indptr=tensor([0, 0]),
              indices=tensor([], dtype=torch.int64),
        ), 'N0:R1:N1': CSCFormatBase(indptr=tensor([0, 1, 2]),
                    indices=tensor([1, 0]),
        ), 'N1:R2:N0': CSCFormatBase(indptr=tensor([0, 2]),
                    indices=tensor([0, 1]),
        ), 'N1:R3:N1': CSCFormatBase(indptr=tensor([0, 1, 3]),
                    indices=tensor([0, 1, 2]),
        )}
        """
        seeds = nodes
        # Per-type IDs are first mapped into the homogeneous ID space.
        if isinstance(seeds, dict):
            seeds = self._convert_to_homogeneous_nodes(seeds)

        # Seeds must be a flat list of IDs ...
        assert seeds.dim() == 1, "Nodes should be 1-D tensor."
        # ... in which no ID shows up twice.
        assert len(seeds) == len(
            torch.unique(seeds)
        ), "Nodes cannot have duplicate values."

        c_subgraph = self._c_csc_graph.in_subgraph(seeds)
        return self._convert_to_sampled_subgraph(c_subgraph)
386

387
    def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
388
        homogeneous_nodes = []
389
        homogeneous_timestamps = []
390
        for ntype, ids in nodes.items():
391
            ntype_id = self.node_type_to_id[ntype]
392
393
394
            homogeneous_nodes.append(
                ids + self.node_type_offset[ntype_id].item()
            )
395
396
397
398
399
400
            if timestamps is not None:
                homogeneous_timestamps.append(timestamps[ntype])
        if timestamps is not None:
            return torch.cat(homogeneous_nodes), torch.cat(
                homogeneous_timestamps
            )
401
402
        return torch.cat(homogeneous_nodes)

403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
    def _convert_to_sampled_subgraph(
        self,
        C_sampled_subgraph: torch.ScriptObject,
    ) -> SampledSubgraphImpl:
        """Turn a fused homogeneous sampled subgraph into the general
        'SampledSubgraphImpl' structure.

        If the sampled subgraph has no `type_per_edge`, its `indptr` and
        `indices` are wrapped directly. Otherwise it is split into one CSC
        structure per edge type, with source IDs rebased to per-type IDs.
        When the graph stores `ORIGINAL_EDGE_ID` as an edge attribute, the
        sampled edge IDs are translated through that attribute first.

        Note that on the heterogeneous path `self.node_type_offset` is moved
        to the device of the sampled column IDs and stored back on the graph.
        """
        indptr = C_sampled_subgraph.indptr
        indices = C_sampled_subgraph.indices
        type_per_edge = C_sampled_subgraph.type_per_edge
        column = C_sampled_subgraph.original_column_node_ids
        original_edge_ids = C_sampled_subgraph.original_edge_ids

        edge_attrs = self.edge_attributes
        has_original_eids = (
            edge_attrs is not None and ORIGINAL_EDGE_ID in edge_attrs
        )
        if has_original_eids:
            # Map sampled edge positions to the stored original edge IDs.
            original_edge_ids = torch.index_select(
                edge_attrs[ORIGINAL_EDGE_ID],
                dim=0,
                index=original_edge_ids,
            )

        if type_per_edge is None:
            # The sampled graph is already a homogeneous graph.
            return SampledSubgraphImpl(
                sampled_csc=CSCFormatBase(indptr=indptr, indices=indices),
                original_edge_ids=original_edge_ids,
            )

        # Keep the offsets on the same device as the sampled columns; the
        # moved tensor is written back to the graph.
        self.node_type_offset = self.node_type_offset.to(column.device)
        type_offsets = self.node_type_offset
        # 1. Look up the node type of every node in column.
        node_types = (
            torch.searchsorted(type_offsets, column, right=True) - 1
        )

        node_type_to_id = self.node_type_to_id
        edge_type_to_id = self.edge_type_to_id
        per_etype_eids = {}
        per_etype_indices = {}
        per_etype_indptr = {}
        # 2. Walk over node types, then over the edge types ending in them.
        for ntype, ntype_id in node_type_to_id.items():
            # Positions in column holding nodes of this type.
            nids = torch.nonzero(node_types == ntype_id).view(-1)
            nids_original_indptr = indptr[nids + 1]
            for etype, etype_id in edge_type_to_id.items():
                src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
                if dst_ntype == ntype:
                    # Positions of all sampled edges of this edge type.
                    eids = torch.nonzero(type_per_edge == etype_id).view(-1)
                    src_start = type_offsets[node_type_to_id[src_ntype]]
                    per_etype_indices[etype] = indices[eids] - src_start
                    # Number of edges of this type up to the end of each
                    # destination node's range.
                    cum_edges = torch.searchsorted(
                        eids, nids_original_indptr, right=False
                    )
                    leading_zero = torch.tensor([0], device=indptr.device)
                    per_etype_indptr[etype] = torch.cat(
                        (leading_zero, cum_edges)
                    )
                    if has_original_eids:
                        per_etype_eids[etype] = original_edge_ids[eids]

        if has_original_eids:
            original_edge_ids = per_etype_eids

        sampled_csc = {}
        for etype in edge_type_to_id.keys():
            sampled_csc[etype] = CSCFormatBase(
                indptr=per_etype_indptr[etype],
                indices=per_etype_indices[etype],
            )
        return SampledSubgraphImpl(
            sampled_csc=sampled_csc,
            original_edge_ids=original_edge_ids,
        )

478
    def sample_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
    ) -> SampledSubgraphImpl:
        """Sample neighboring edges of the given nodes and return the
        induced subgraph.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: It means the graph is homogeneous
                graph, and ids inside are homogeneous ids.
              - If `nodes` is a dictionary: The keys should be node type and
                ids inside are heterogeneous ids.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal to the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute used.
            This attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        SampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_neighbors(nodes, fanouts)
        >>> print(subgraph.sampled_csc)
        {'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 1]),
                    indices=tensor([0]),
        ), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1]),
                    indices=tensor([2]),
        )}
        """
        seeds = nodes
        # Per-type IDs are first mapped into the homogeneous ID space.
        if isinstance(seeds, dict):
            seeds = self._convert_to_homogeneous_nodes(seeds)

        # Sample on the C side, then repackage the raw result.
        return self._convert_to_sampled_subgraph(
            self._sample_neighbors(seeds, fanouts, replace, probs_name)
        )
561

562
563
    def _check_sampler_arguments(self, nodes, fanouts, probs_name):
        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
564
565
566
567
        assert nodes.dtype == self.indices.dtype, (
            f"Data type of nodes must be consistent with "
            f"indices.dtype({self.indices.dtype}), but got {nodes.dtype}."
        )
568
569
        assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
        expected_fanout_len = 1
570
571
        if self.edge_type_to_id:
            expected_fanout_len = len(self.edge_type_to_id)
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
        assert len(fanouts) in [
            expected_fanout_len,
            1,
        ], "Fanouts should have the same number of elements as etypes or \
            should have a length of 1."
        if fanouts.size(0) > 1:
            assert (
                self.type_per_edge is not None
            ), "To perform sampling for each edge type (when the length of \
                `fanouts` > 1), the graph must include edge type information."
        assert torch.all(
            (fanouts >= 0) | (fanouts == -1)
        ), "Fanouts should consist of values that are either -1 or \
            greater than or equal to 0."
        if probs_name:
            assert (
                probs_name in self.edge_attributes
            ), f"Unknown edge attribute '{probs_name}'."
            probs_or_mask = self.edge_attributes[probs_name]
            assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
            assert (
593
                probs_or_mask.size(0) == self.total_num_edges
594
595
596
597
598
599
600
601
602
603
            ), "Probs should have the same number of elements as the number \
                of edges."
            assert probs_or_mask.dtype in [
                torch.bool,
                torch.float16,
                torch.bfloat16,
                torch.float32,
                torch.float64,
            ], "Probs should have a floating-point or boolean data type."

604
    def _sample_neighbors(
        self,
        nodes: torch.Tensor,
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
    ) -> torch.ScriptObject:
        """Sample neighboring edges of the given nodes and return the
        induced subgraph as the raw C object.

        Parameters
        ----------
        nodes: torch.Tensor
            IDs of the given seed nodes.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal to the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        torch.classes.graphbolt.SampledSubgraph
            The sampled C subgraph.
        """
        # Reject malformed nodes / fanouts / probs_name before calling into C.
        self._check_sampler_arguments(nodes, fanouts, probs_name)

        edge_attrs = self.edge_attributes
        has_original_eids = (
            edge_attrs is not None and ORIGINAL_EDGE_ID in edge_attrs
        )
        # The fourth positional argument is False here and True in
        # `sample_layer_neighbors`; presumably it selects layer-neighbor
        # sampling on the C side - confirm against the C API.
        c_subgraph = self._c_csc_graph.sample_neighbors(
            nodes,
            fanouts.tolist(),
            replace,
            False,
            has_original_eids,
            probs_name,
        )
        return c_subgraph

    def sample_layer_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
    ) -> SampledSubgraphImpl:
        """Sample neighboring edges of the given nodes and return the
        induced subgraph via layer-neighbor sampling from the NeurIPS 2023
        paper `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in
        GNNs <https://arxiv.org/abs/2210.13339>`__

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: It means the graph is homogeneous
                graph, and ids inside are homogeneous ids.
              - If `nodes` is a dictionary: The keys should be node type and
                ids inside are heterogeneous ids.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal to the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        SampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
        >>> print(subgraph.sampled_csc)
        {'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 1]),
                    indices=tensor([0]),
        ), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1]),
                    indices=tensor([2]),
        )}
        """
        seeds = nodes
        # Per-type IDs are first mapped into the homogeneous ID space.
        if isinstance(seeds, dict):
            seeds = self._convert_to_homogeneous_nodes(seeds)

        # Reject malformed seeds / fanouts / probs_name before calling into C.
        self._check_sampler_arguments(seeds, fanouts, probs_name)

        edge_attrs = self.edge_attributes
        has_original_eids = (
            edge_attrs is not None and ORIGINAL_EDGE_ID in edge_attrs
        )
        # Same C entry point as `_sample_neighbors`, but with the fourth
        # positional argument set to True.
        c_subgraph = self._c_csc_graph.sample_neighbors(
            seeds,
            fanouts.tolist(),
            replace,
            True,
            has_original_eids,
            probs_name,
        )
        return self._convert_to_sampled_subgraph(c_subgraph)
761

762
    def temporal_sample_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        input_nodes_timestamp: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
        node_timestamp_attr_name: Optional[str] = None,
        edge_timestamp_attr_name: Optional[str] = None,
    ) -> "SampledSubgraphImpl":
        """Temporally sample neighboring edges of the given nodes and return
        the induced subgraph.

        If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given,
        the sampled neighbors or edges of an input node must have a timestamp
        that is no later than that of the input node.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes. A dictionary is converted to
            homogeneous node IDs before sampling.
        input_nodes_timestamp: torch.Tensor or Dict[str, torch.Tensor]
            Timestamps of the given seed nodes. When `nodes` is a dictionary,
            this is converted together with it.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal to the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.
        node_timestamp_attr_name: str, optional
            An optional string specifying the name of a node attribute.
        edge_timestamp_attr_name: str, optional
            An optional string specifying the name of an edge attribute.

        Returns
        -------
        SampledSubgraphImpl
            The sampled subgraph.
        """
        if isinstance(nodes, dict):
            nodes, input_nodes_timestamp = self._convert_to_homogeneous_nodes(
                nodes, input_nodes_timestamp
            )

        # Validate the sampler inputs before handing them to the C++ graph.
        self._check_sampler_arguments(nodes, fanouts, probs_name)
        # Tell the C++ sampler whether the graph carries original edge IDs.
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        C_sampled_subgraph = self._c_csc_graph.temporal_sample_neighbors(
            nodes,
            input_nodes_timestamp,
            fanouts.tolist(),
            replace,
            has_original_eids,
            probs_name,
            node_timestamp_attr_name,
            edge_timestamp_attr_name,
        )
        return self._convert_to_sampled_subgraph(C_sampled_subgraph)
844

845
846
847
848
849
850
851
852
853
854
855
856
    def sample_negative_edges_uniform(
        self, edge_type, node_pairs, negative_ratio
    ):
        """Sample negative edges uniformly at random.

        Negative source-destination pairs are chosen according to a uniform
        distribution. For each edge ``(u, v)``, it is supposed to generate
        `negative_ratio` pairs of negative edges ``(u, v')``, where ``v'`` is
        chosen uniformly from all the nodes in the graph.

        Parameters
        ----------
        edge_type: str
            The type of edges in the provided node_pairs. Any negative edges
            sampled will also have the same type. If set to None, it will be
            considered as a homogeneous graph.
        node_pairs : Tuple[Tensor, Tensor]
            A tuple of two 1D tensors that represent the source and destination
            of positive edges, with 'positive' indicating that these edges are
            present in the graph. It's important to note that within the
            context of a heterogeneous graph, the ids in these tensors signify
            heterogeneous ids.
        negative_ratio: int
            The ratio of the number of negative samples to positive samples.

        Returns
        -------
        Tuple[Tensor, Tensor]
            A tuple consisting of two 1D tensors represents the source and
            destination of negative edges. In the context of a heterogeneous
            graph, both the input nodes and the selected nodes are represented
            by heterogeneous IDs, and the formed edges are of the input type
            `edge_type`. Note that negative refers to false negatives, which
            means the edge could be present or not present in the graph.
        """
        if edge_type is not None:
            # Adjacent string literals keep the message free of the
            # indentation that a backslash continuation would embed.
            assert self.node_type_offset is not None, (
                "The 'node_type_offset' array is necessary for performing "
                "negative sampling by edge type."
            )
            _, _, dst_node_type = etype_str_to_tuple(edge_type)
            dst_node_type_id = self.node_type_to_id[dst_node_type]
            # Width of the destination type's slice of `node_type_offset`.
            max_node_id = (
                self.node_type_offset[dst_node_type_id + 1]
                - self.node_type_offset[dst_node_type_id]
            )
        else:
            max_node_id = self.total_num_nodes
        return self._c_csc_graph.sample_negative_edges_uniform(
            node_pairs,
            negative_ratio,
            max_node_id,
        )

899
900
901
902
903
904
905
906
907
908
    def copy_to_shared_memory(self, shared_memory_name: str):
        """Copy the graph to shared memory.

        Parameters
        ----------
        shared_memory_name : str
            Name of the shared memory.

        Returns
        -------
        FusedCSCSamplingGraph
            The copied FusedCSCSamplingGraph object on shared memory.
        """
        # The C++ graph performs the copy; wrap the result it returns in a
        # new Python-side graph object.
        shared_c_graph = self._c_csc_graph.copy_to_shared_memory(
            shared_memory_name
        )
        return FusedCSCSamplingGraph(shared_c_graph)

916
917
918
919
920
921
922
923
924
925
926
    def _apply_to_members(self, fn):
        """Run `fn` over each member of `FusedCSCSamplingGraph` in place.

        Every listed member is read, passed through `recursive_apply` with
        `fn`, and the result is assigned back to the same attribute. Returns
        `self`.
        """
        member_names = (
            "csc_indptr",
            "indices",
            "node_type_offset",
            "type_per_edge",
            "node_attributes",
            "edge_attributes",
        )
        for member in member_names:
            current = getattr(self, member)
            setattr(self, member, recursive_apply(current, fn))

        return self

927
    def to(  # pylint: disable=invalid-name
        self, device: torch.device
    ) -> "FusedCSCSamplingGraph":
        """Copy `FusedCSCSamplingGraph` to the specified device.

        Parameters
        ----------
        device : torch.device
            The target device, forwarded to each member's ``to`` method.

        Returns
        -------
        FusedCSCSamplingGraph
            The value returned by `_apply_to_members`, i.e. this graph.
        """

        def _to(x):
            # Members without a ``to`` method are passed through unchanged.
            return x.to(device) if hasattr(x, "to") else x

        return self._apply_to_members(_to)
934

935
936
937
938
    def pin_memory_(self):
        """Copy `FusedCSCSamplingGraph` to the pinned memory in-place.

        Each member exposing ``pin_memory`` is replaced by the result of
        calling it; other members are left as they are. Returns None.
        """

        def _pin(x):
            # Members without a ``pin_memory`` method pass through unchanged.
            return x.pin_memory() if hasattr(x, "pin_memory") else x

        self._apply_to_members(_pin)
942

943

944
def fused_csc_sampling_graph(
    csc_indptr: torch.Tensor,
    indices: torch.Tensor,
    node_type_offset: Optional[torch.Tensor] = None,
    type_per_edge: Optional[torch.Tensor] = None,
    node_type_to_id: Optional[Dict[str, int]] = None,
    edge_type_to_id: Optional[Dict[str, int]] = None,
    node_attributes: Optional[Dict[str, torch.Tensor]] = None,
    edge_attributes: Optional[Dict[str, torch.Tensor]] = None,
) -> "FusedCSCSamplingGraph":
    """Create a FusedCSCSamplingGraph object from a CSC representation.

    When both `node_type_to_id` and `edge_type_to_id` are given, the type
    mappings are validated with assertions before the graph is built.

    Parameters
    ----------
    csc_indptr : torch.Tensor
        Pointer to the start of each row in the `indices`. An integer tensor
        with shape `(total_num_nodes+1,)`.
    indices : torch.Tensor
        Column indices of the non-zero elements in the CSC graph. An integer
        tensor with shape `(total_num_edges,)`.
    node_type_offset : Optional[torch.Tensor], optional
        Offset of node types in the graph, by default None.
    type_per_edge : Optional[torch.Tensor], optional
        Type ids of each edge in the graph, by default None.
    node_type_to_id : Optional[Dict[str, int]], optional
        Map node types to ids, by default None.
    edge_type_to_id : Optional[Dict[str, int]], optional
        Map edge types to ids, by default None.
    node_attributes: Optional[Dict[str, torch.Tensor]], optional
        Node attributes of the graph, by default None.
    edge_attributes: Optional[Dict[str, torch.Tensor]], optional
        Edge attributes of the graph, by default None.

    Returns
    -------
    FusedCSCSamplingGraph
        The created FusedCSCSamplingGraph object.

    Examples
    --------
    >>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
    >>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}
    >>> csc_indptr = torch.tensor([0, 2, 5, 7])
    >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
    >>> node_type_offset = torch.tensor([0, 1, 2, 3])
    >>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
    >>> graph = graphbolt.fused_csc_sampling_graph(csc_indptr, indices,
    ...             node_type_offset=node_type_offset,
    ...             type_per_edge=type_per_edge,
    ...             node_type_to_id=ntypes, edge_type_to_id=etypes,
    ...             node_attributes=None, edge_attributes=None,)
    >>> print(graph)  # remaining metadata fields elided
    FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
                          indices=tensor([1, 3, 0, 1, 2, 0, 3]),
                          num_nodes=3, num_edges=..., ...)
    """
    if node_type_to_id is not None and edge_type_to_id is not None:
        node_type_ids = list(node_type_to_id.values())
        edge_type_ids = list(edge_type_to_id.values())

        # Validate node_type_to_id.
        assert all(
            isinstance(x, str) for x in node_type_to_id
        ), "Node type name should be string."
        assert all(
            isinstance(x, int) for x in node_type_ids
        ), "Node type id should be int."
        assert len(node_type_ids) == len(
            set(node_type_ids)
        ), "Multiple node types shoud not be mapped to a same id."
        # Validate edge_type_to_id: both endpoints of every edge type must be
        # known node types.
        for edge_type in edge_type_to_id:
            src, edge, dst = etype_str_to_tuple(edge_type)
            assert isinstance(edge, str), "Edge type name should be string."
            assert (
                src in node_type_to_id
            ), f"Unrecognized node type {src} in edge type {edge_type}"
            assert (
                dst in node_type_to_id
            ), f"Unrecognized node type {dst} in edge type {edge_type}"
        assert all(
            isinstance(x, int) for x in edge_type_ids
        ), "Edge type id should be int."
        assert len(edge_type_ids) == len(
            set(edge_type_ids)
        ), "Multiple edge types shoud not be mapped to a same id."

        if node_type_offset is not None:
            assert len(node_type_to_id) + 1 == node_type_offset.size(
                0
            ), "node_type_offset length should be |ntypes| + 1."
    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.fused_csc_sampling_graph(
            csc_indptr,
            indices,
            node_type_offset,
            type_per_edge,
            node_type_to_id,
            edge_type_to_id,
            node_attributes,
            edge_attributes,
        ),
    )


1051
1052
def load_from_shared_memory(
    shared_memory_name: str,
) -> FusedCSCSamplingGraph:
    """Load a FusedCSCSamplingGraph object from shared memory.

    Parameters
    ----------
    shared_memory_name : str
        Name of the shared memory.

    Returns
    -------
    FusedCSCSamplingGraph
        The loaded FusedCSCSamplingGraph object on shared memory.
    """
    # The graphbolt op returns the C++ graph object; wrap it in the Python
    # graph class.
    c_csc_graph = torch.ops.graphbolt.load_from_shared_memory(
        shared_memory_name
    )
    return FusedCSCSamplingGraph(c_csc_graph)


1071
def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
1072
1073
1074
1075
1076
    """Internal function for converting a csc sampling graph to string
    representation.
    """
    csc_indptr_str = str(graph.csc_indptr)
    indices_str = str(graph.indices)
1077
1078
1079
1080
1081
1082
1083
1084
1085
    meta_str = f"num_nodes={graph.total_num_nodes}, num_edges={graph.num_edges}"
    if graph.node_type_offset is not None:
        meta_str += f", node_type_offset={graph.node_type_offset}"
    if graph.type_per_edge is not None:
        meta_str += f", type_per_edge={graph.type_per_edge}"
    if graph.node_type_to_id is not None:
        meta_str += f", node_type_to_id={graph.node_type_to_id}"
    if graph.edge_type_to_id is not None:
        meta_str += f", edge_type_to_id={graph.edge_type_to_id}"
1086
1087
    if graph.node_attributes is not None:
        meta_str += f", node_attributes={graph.node_attributes}"
1088
1089
1090
    if graph.edge_attributes is not None:
        meta_str += f", edge_attributes={graph.edge_attributes}"

1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
    prefix = f"{type(graph).__name__}("

    def _add_indent(_str, indent):
        lines = _str.split("\n")
        lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
        return "\n".join(lines)

    final_str = (
        "csc_indptr="
        + _add_indent(csc_indptr_str, len("csc_indptr="))
        + ",\n"
        + "indices="
        + _add_indent(indices_str, len("indices="))
        + ",\n"
        + meta_str
        + ")"
    )

    final_str = prefix + _add_indent(final_str, len(prefix))
    return final_str
1111
1112


1113
1114
1115
1116
def from_dglgraph(
    g: DGLGraph,
    is_homogeneous: bool = False,
    include_original_edge_id: bool = False,
) -> FusedCSCSamplingGraph:
    """Convert a DGLGraph to FusedCSCSamplingGraph.

    Parameters
    ----------
    g : DGLGraph
        The graph to convert. It is first flattened with `to_homogeneous`,
        carrying over its node and edge data.
    is_homogeneous : bool, optional
        If True, no type metadata (type-to-id maps, node type offsets,
        per-edge types) is attached to the result. By default False.
    include_original_edge_id : bool, optional
        If True, the original edge IDs, reordered to follow the CSC layout,
        are stored as the `ORIGINAL_EDGE_ID` edge attribute. By default False.

    Returns
    -------
    FusedCSCSamplingGraph
        The converted graph.
    """
    homo_g, ntype_count, _ = to_homogeneous(
        g, ndata=g.ndata, edata=g.edata, return_count=True
    )

    if is_homogeneous:
        node_type_to_id = None
        edge_type_to_id = None
    else:
        # Initialize metadata.
        node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
        edge_type_to_id = {
            etype_tuple_to_str(etype): g.get_etype_id(etype)
            for etype in g.canonical_etypes
        }

    # Obtain CSC matrix.
    indptr, indices, edge_ids = homo_g.adj_tensors("csc")
    # Prepend 0 so the cumulative sum yields offsets that start at zero.
    ntype_count.insert(0, 0)
    node_type_offset = (
        None
        if is_homogeneous
        else torch.cumsum(torch.LongTensor(ntype_count), 0)
    )

    # Assign edge type according to the order of CSC matrix.
    type_per_edge = (
        None
        if is_homogeneous
        else torch.index_select(homo_g.edata[ETYPE], dim=0, index=edge_ids)
    )

    # Carry over user features, skipping the bookkeeping fields.
    node_attributes = {
        feat_name: feat_data
        for feat_name, feat_data in homo_g.ndata.items()
        if feat_name not in (NID, NTYPE)
    }
    # NOTE(review): these edge features are copied without applying
    # `edge_ids`, unlike `type_per_edge` and ORIGINAL_EDGE_ID which are
    # reordered to the CSC layout - confirm whether they should be permuted.
    edge_attributes = {
        feat_name: feat_data
        for feat_name, feat_data in homo_g.edata.items()
        if feat_name not in (EID, ETYPE)
    }
    if include_original_edge_id:
        # Assign edge attributes according to the original eids mapping.
        edge_attributes[ORIGINAL_EDGE_ID] = torch.index_select(
            homo_g.edata[EID], dim=0, index=edge_ids
        )

    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.fused_csc_sampling_graph(
            indptr,
            indices,
            node_type_offset,
            type_per_edge,
            node_type_to_id,
            edge_type_to_id,
            node_attributes,
            edge_attributes,
        ),
    )