"""CSC format sampling graph."""
# pylint: disable= invalid-name
from collections import defaultdict
from typing import Dict, Optional, Union

import torch

from dgl.utils import recursive_apply

from ...base import EID, ETYPE
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import (
    CSCFormatBase,
    FusedSampledSubgraphImpl,
    SampledSubgraphImpl,
)


__all__ = [
    "FusedCSCSamplingGraph",
    "fused_csc_sampling_graph",
    "load_from_shared_memory",
    "from_dglgraph",
]


class FusedCSCSamplingGraph(SamplingGraph):
    r"""A sampling graph in CSC format."""

    def __repr__(self):
        return _csc_sampling_graph_str(self)

    def __init__(
        self,
        c_csc_graph: torch.ScriptObject,
    ):
        super().__init__()
        self._c_csc_graph = c_csc_graph

    @property
    def total_num_nodes(self) -> int:
        """Returns the number of nodes in the graph.

        Returns
        -------
        int
            The total number of nodes in the graph.
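
        Examples
        --------
        A minimal sketch with a tiny homogeneous graph; the values are
        illustrative only:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 2, 4, 5])
        >>> indices = torch.LongTensor([1, 2, 0, 2, 1])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
        >>> graph.total_num_nodes
        3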
        """
        return self._c_csc_graph.num_nodes()

    @property
    def total_num_edges(self) -> int:
        """Returns the number of edges in the graph.

        Returns
        -------
        int
            The number of edges in the graph.
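
        Examples
        --------
        A minimal sketch with a tiny homogeneous graph; the values are
        illustrative only:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 2, 4, 5])
        >>> indices = torch.LongTensor([1, 2, 0, 2, 1])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
        >>> graph.total_num_edges
        5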
        """
        return self._c_csc_graph.num_edges()

    @property
    def num_nodes(self) -> Union[int, Dict[str, int]]:
        """The number of nodes in the graph.
        - If the graph is homogenous, returns an integer.
        - If the graph is heterogenous, returns a dictionary.

        Returns
        -------
        Union[int, Dict[str, int]]
            The number of nodes. Integer indicates the total nodes number of a
            homogenous graph; dict indicates nodes number per node types of a
            heterogenous graph.

        Examples
        --------
        >>> import dgl.graphbolt as gb, torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
        ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> print(graph.num_nodes)
        {'N0': 2, 'N1': 3}
        """

        offset = self.node_type_offset

        # Homogeneous.
        if offset is None or self.node_type_to_id is None:
            return self._c_csc_graph.num_nodes()

        # Heterogeneous.
        else:
            num_nodes_per_type = {
                _type: (offset[_idx + 1] - offset[_idx]).item()
                for _type, _idx in self.node_type_to_id.items()
            }

            return num_nodes_per_type

    @property
    def num_edges(self) -> Union[int, Dict[str, int]]:
        """The number of edges in the graph.
        - If the graph is homogenous, returns an integer.
        - If the graph is heterogenous, returns a dictionary.

        Returns
        -------
        Union[int, Dict[str, int]]
            The number of edges. Integer indicates the total edges number of a
            homogenous graph; dict indicates edges number per edge types of a
            heterogenous graph.

        Examples
        --------
        >>> import dgl.graphbolt as gb, torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
        ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> print(graph.num_edges)
        {'N0:R0:N0': 2, 'N0:R1:N1': 4, 'N1:R2:N0': 3, 'N1:R3:N1': 3}
        """

        type_per_edge = self.type_per_edge

        # Homogeneous.
        if type_per_edge is None or self.edge_type_to_id is None:
            return self._c_csc_graph.num_edges()

        # Heterogeneous.
        bincount = torch.bincount(type_per_edge)
        num_edges_per_type = {}
        for etype, etype_id in self.edge_type_to_id.items():
            if etype_id < len(bincount):
                num_edges_per_type[etype] = bincount[etype_id].item()
            else:
                num_edges_per_type[etype] = 0
        return num_edges_per_type

    @property
    def csc_indptr(self) -> torch.Tensor:
        """Returns the indices pointer in the CSC graph.

        Returns
        -------
        torch.Tensor
            The indices pointer in the CSC graph. An integer tensor with
            shape `(total_num_nodes+1,)`.
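
        Examples
        --------
        A minimal sketch with a tiny homogeneous graph; the values are
        illustrative only:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 2, 4, 5])
        >>> indices = torch.LongTensor([1, 2, 0, 2, 1])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
        >>> graph.csc_indptr
        tensor([0, 2, 4, 5])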
        """
        return self._c_csc_graph.csc_indptr()

    @csc_indptr.setter
    def csc_indptr(self, csc_indptr: torch.Tensor) -> None:
        """Sets the indices pointer in the CSC graph."""
        self._c_csc_graph.set_csc_indptr(csc_indptr)

    @property
    def indices(self) -> torch.Tensor:
        """Returns the indices in the CSC graph.

        Returns
        -------
        torch.Tensor
            The indices in the CSC graph. An integer tensor with shape
            `(total_num_edges,)`.

        Notes
        -------
        It is assumed that edges of each node are already sorted by edge type
        ids.
        """
        return self._c_csc_graph.indices()

    @indices.setter
    def indices(self, indices: torch.Tensor) -> None:
        """Sets the indices in the CSC graph."""
        self._c_csc_graph.set_indices(indices)

    @property
    def node_type_offset(self) -> Optional[torch.Tensor]:
        """Returns the node type offset tensor if present.

        Returns
        -------
        torch.Tensor or None
            If present, returns a 1D integer tensor of shape
            `(num_node_types + 1,)`. The tensor is in ascending order as nodes
            of the same type have consecutive IDs, and larger node IDs are
            paired with larger node type IDs. The first value is 0 and the
            last value is the total number of nodes. Nodes with IDs in the
            range `node_type_offset[i]` to `node_type_offset[i+1]` are of
            type id `i`.
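
        Examples
        --------
        A minimal sketch reusing the heterogeneous layout from the
        ``num_nodes`` example:

        >>> import dgl.graphbolt as gb, torch
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
        ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> graph.node_type_offset
        tensor([0, 2, 5])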

        """
        return self._c_csc_graph.node_type_offset()

    @node_type_offset.setter
    def node_type_offset(
        self, node_type_offset: Optional[torch.Tensor]
    ) -> None:
        """Sets the node type offset tensor if present."""
        self._c_csc_graph.set_node_type_offset(node_type_offset)

    @property
    def type_per_edge(self) -> Optional[torch.Tensor]:
        """Returns the edge type tensor if present.

        Returns
        -------
        torch.Tensor or None
            If present, returns a 1D integer tensor of shape (total_num_edges,)
            containing the type of each edge in the graph.
        """
        return self._c_csc_graph.type_per_edge()

    @type_per_edge.setter
    def type_per_edge(self, type_per_edge: Optional[torch.Tensor]) -> None:
        """Sets the edge type tensor if present."""
        self._c_csc_graph.set_type_per_edge(type_per_edge)

    @property
    def node_type_to_id(self) -> Optional[Dict[str, int]]:
        """Returns the node type to id dictionary if present.

        Returns
        -------
        Dict[str, int] or None
            If present, returns a dictionary mapping node type to node type
            id.
        """
        return self._c_csc_graph.node_type_to_id()

    @node_type_to_id.setter
    def node_type_to_id(
        self, node_type_to_id: Optional[Dict[str, int]]
    ) -> None:
        """Sets the node type to id dictionary if present."""
        self._c_csc_graph.set_node_type_to_id(node_type_to_id)

    @property
    def edge_type_to_id(self) -> Optional[Dict[str, int]]:
        """Returns the edge type to id dictionary if present.

        Returns
        -------
        Dict[str, int] or None
            If present, returns a dictionary mapping edge type to edge type
            id.
        """
        return self._c_csc_graph.edge_type_to_id()

    @edge_type_to_id.setter
    def edge_type_to_id(
        self, edge_type_to_id: Optional[Dict[str, int]]
    ) -> None:
        """Sets the edge type to id dictionary if present."""
        self._c_csc_graph.set_edge_type_to_id(edge_type_to_id)

    @property
    def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
        """Returns the edge attributes dictionary.

        Returns
        -------
        Dict[str, torch.Tensor] or None
            If present, returns a dictionary of edge attributes. Each key
            represents the attribute's name, while the corresponding value
            holds the attribute's specific value. The length of each value
            should match the total number of edges.
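
        Examples
        --------
        A minimal sketch; the attribute name "weight" is arbitrary and only
        used for illustration:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 2, 4, 5])
        >>> indices = torch.LongTensor([1, 2, 0, 2, 1])
        >>> weight = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
        >>> graph = gb.fused_csc_sampling_graph(
        ...     indptr, indices, edge_attributes={"weight": weight})
        >>> graph.edge_attributes["weight"]
        tensor([1., 2., 3., 4., 5.])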
        """
        return self._c_csc_graph.edge_attributes()

    @edge_attributes.setter
    def edge_attributes(
        self, edge_attributes: Optional[Dict[str, torch.Tensor]]
    ) -> None:
        """Sets the edge attributes dictionary."""
        self._c_csc_graph.set_edge_attributes(edge_attributes)

    def in_subgraph(
        self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
    ) -> FusedSampledSubgraphImpl:
        """Return the subgraph induced on the inbound edges of the given nodes.

        An in subgraph is equivalent to creating a new graph using the incoming
        edges of the given nodes. The subgraph is compacted according to the
        order of the passed-in `nodes`.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: the graph is treated as a homogeneous
                graph, and the IDs inside are homogeneous IDs.
              - If `nodes` is a dictionary: the keys are node types, and the
                IDs inside are heterogeneous IDs.

        Returns
        -------
        FusedSampledSubgraphImpl
            The in subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {
        ...     "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
        >>> in_subgraph = graph.in_subgraph(nodes)
        >>> print(in_subgraph.node_pairs)
        defaultdict(<class 'list'>, {
            'N0:R0:N0': (tensor([]), tensor([])),
            'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])),
            'N1:R2:N0': (tensor([0, 1]), tensor([1, 1])),
            'N1:R3:N1': (tensor([0, 1, 2]), tensor([1, 2, 2]))}
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)
        # Ensure nodes is 1-D tensor.
        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
        # Ensure that there are no duplicate nodes.
        assert len(torch.unique(nodes)) == len(
            nodes
        ), "Nodes cannot have duplicate values."

        _in_subgraph = self._c_csc_graph.in_subgraph(nodes)
        return self._convert_to_fused_sampled_subgraph(_in_subgraph)

    def _convert_to_fused_sampled_subgraph(
        self,
        C_sampled_subgraph: torch.ScriptObject,
    ):
        """An internal function used to convert a fused homogeneous sampled
371
        subgraph to general struct 'FusedSampledSubgraphImpl'."""
        column_num = (
            C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
        )
        column = C_sampled_subgraph.original_column_node_ids.repeat_interleave(
            column_num
        )
        row = C_sampled_subgraph.indices
        type_per_edge = C_sampled_subgraph.type_per_edge
        original_edge_ids = C_sampled_subgraph.original_edge_ids
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        if has_original_eids:
            original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
                original_edge_ids
            ]
        if type_per_edge is None:
            # The sampled graph is already a homogeneous graph.
            node_pairs = (row, column)
        else:
            # The sampled graph is a fused homogenized graph, which needs to
            # be converted back to its heterogeneous form.
            node_pairs = defaultdict(list)
            original_hetero_edge_ids = {}
            for etype, etype_id in self.edge_type_to_id.items():
                src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
                src_ntype_id = self.node_type_to_id[src_ntype]
                dst_ntype_id = self.node_type_to_id[dst_ntype]
                mask = type_per_edge == etype_id
                hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
                hetero_column = (
                    column[mask] - self.node_type_offset[dst_ntype_id]
                )
                node_pairs[etype] = (hetero_row, hetero_column)
                if has_original_eids:
                    original_hetero_edge_ids[etype] = original_edge_ids[mask]
            if has_original_eids:
                original_edge_ids = original_hetero_edge_ids
        return FusedSampledSubgraphImpl(
            node_pairs=node_pairs, original_edge_ids=original_edge_ids
        )

    def _convert_to_homogeneous_nodes(self, nodes):
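        """Convert heterogeneous node IDs, given as a dictionary keyed by
        node type, to homogeneous node IDs by adding the per-type offsets
        from `node_type_offset`."""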
        homogeneous_nodes = []
        for ntype, ids in nodes.items():
            ntype_id = self.node_type_to_id[ntype]
            homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
        return torch.cat(homogeneous_nodes)

    def _convert_to_sampled_subgraph(
        self,
        C_sampled_subgraph: torch.ScriptObject,
    ) -> SampledSubgraphImpl:
        """An internal function used to convert a fused homogeneous sampled
        subgraph to general struct 'SampledSubgraphImpl'."""
        indptr = C_sampled_subgraph.indptr
        indices = C_sampled_subgraph.indices
        type_per_edge = C_sampled_subgraph.type_per_edge
        column = C_sampled_subgraph.original_column_node_ids
        original_edge_ids = C_sampled_subgraph.original_edge_ids

        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        if has_original_eids:
            original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
                original_edge_ids
            ]
        if type_per_edge is None:
            # The sampled graph is already a homogeneous graph.
            node_pairs = CSCFormatBase(indptr=indptr, indices=indices)
        else:
            # The sampled graph is a fused homogenized graph, which needs to
            # be converted back to its heterogeneous form.
            # Pre-compute the number of edges of each edge type.
            num = {}
            for etype in type_per_edge:
                num[etype.item()] = num.get(etype.item(), 0) + 1
            # Preallocate
            subgraph_indice_position = {}
            subgraph_indice = {}
            subgraph_indptr = {}
            node_edge_type = defaultdict(list)
            original_hetero_edge_ids = {}
            for etype, etype_id in self.edge_type_to_id.items():
                subgraph_indice[etype] = torch.empty(
                    (num.get(etype_id, 0),), dtype=indices.dtype
                )
                if has_original_eids:
                    original_hetero_edge_ids[etype] = torch.empty(
                        (num.get(etype_id, 0),), dtype=original_edge_ids.dtype
                    )
                subgraph_indptr[etype] = [0]
                subgraph_indice_position[etype] = 0
                # Group edge types by destination node type so that each
                # seed node only visits the edge types whose destination
                # matches its own node type.
                _, _, dst_ntype = etype_str_to_tuple(etype)
                dst_ntype_id = self.node_type_to_id[dst_ntype]
                node_edge_type[dst_ntype_id].append((etype, etype_id))
            # Construct the per-edge-type subgraphs seed by seed.
            for i, seed in enumerate(column):
                l = indptr[i].item()
                r = indptr[i + 1].item()
                node_type = (
                    torch.searchsorted(
                        self.node_type_offset, seed, right=True
                    ).item()
                    - 1
                )
                for etype, etype_id in node_edge_type[node_type]:
                    src_ntype, _, _ = etype_str_to_tuple(etype)
                    src_ntype_id = self.node_type_to_id[src_ntype]
                    num_edges = torch.searchsorted(
                        type_per_edge[l:r], etype_id, right=True
                    ).item()
                    end = num_edges + l
                    subgraph_indptr[etype].append(
                        subgraph_indptr[etype][-1] + num_edges
                    )
                    offset = subgraph_indice_position[etype]
                    subgraph_indice_position[etype] += num_edges
                    subgraph_indice[etype][offset : offset + num_edges] = (
                        indices[l:end] - self.node_type_offset[src_ntype_id]
                    )
                    if has_original_eids:
                        original_hetero_edge_ids[etype][
                            offset : offset + num_edges
                        ] = original_edge_ids[l:end]
                    l = end
            if has_original_eids:
                original_edge_ids = original_hetero_edge_ids
            node_pairs = {
                etype: CSCFormatBase(
                    indptr=torch.tensor(subgraph_indptr[etype]),
                    indices=subgraph_indice[etype],
                )
                for etype in self.edge_type_to_id.keys()
            }
        return SampledSubgraphImpl(
            node_pairs=node_pairs,
            original_edge_ids=original_edge_ids,
        )

    def sample_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
        # TODO: clean up once the migration is done.
        output_cscformat=False,
    ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: the graph is treated as a homogeneous
                graph, and the IDs inside are homogeneous IDs.
              - If `nodes` is a dictionary: the keys are node types, and the
                IDs inside are heterogeneous IDs.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or equal to -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute.
            This attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        FusedSampledSubgraphImpl or SampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_neighbors(nodes, fanouts)
        >>> print(subgraph.node_pairs)
        defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([0]),
          tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)

        C_sampled_subgraph = self._sample_neighbors(
            nodes, fanouts, replace, probs_name
        )
        if not output_cscformat:
            return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
        else:
            return self._convert_to_sampled_subgraph(C_sampled_subgraph)

    def _check_sampler_arguments(self, nodes, fanouts, probs_name):
        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
        assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
        expected_fanout_len = 1
        if self.edge_type_to_id:
            expected_fanout_len = len(self.edge_type_to_id)
        assert len(fanouts) in [
            expected_fanout_len,
            1,
        ], "Fanouts should have the same number of elements as etypes or \
            should have a length of 1."
        if fanouts.size(0) > 1:
            assert (
                self.type_per_edge is not None
            ), "To perform sampling for each edge type (when the length of \
                `fanouts` > 1), the graph must include edge type information."
        assert torch.all(
            (fanouts >= 0) | (fanouts == -1)
        ), "Fanouts should consist of values that are either -1 or \
            greater than or equal to 0."
        if probs_name:
            assert (
                probs_name in self.edge_attributes
            ), f"Unknown edge attribute '{probs_name}'."
            probs_or_mask = self.edge_attributes[probs_name]
            assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
            assert (
                probs_or_mask.size(0) == self.total_num_edges
            ), "Probs should have the same number of elements as the number \
                of edges."
            assert probs_or_mask.dtype in [
                torch.bool,
                torch.float16,
                torch.bfloat16,
                torch.float32,
                torch.float64,
            ], "Probs should have a floating-point or boolean data type."

    def _sample_neighbors(
        self,
        nodes: torch.Tensor,
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
    ) -> torch.ScriptObject:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph.

        Parameters
        ----------
        nodes: torch.Tensor
            IDs of the given seed nodes.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or equal to -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        torch.classes.graphbolt.SampledSubgraph
            The sampled C subgraph.
        """
        # Validate nodes, fanouts and probs_name.
        self._check_sampler_arguments(nodes, fanouts, probs_name)
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        return self._c_csc_graph.sample_neighbors(
            nodes,
            fanouts.tolist(),
            replace,
            False,
            has_original_eids,
            probs_name,
        )

    def sample_layer_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
        # TODO: clean up once the migration is done.
        output_cscformat=False,
    ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
        `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
        <https://arxiv.org/abs/2210.13339>`__

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: the graph is treated as a homogeneous
                graph, and the IDs inside are homogeneous IDs.
              - If `nodes` is a dictionary: the keys are node types, and the
                IDs inside are heterogeneous IDs.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or equal to -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        FusedSampledSubgraphImpl or SampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
        >>> print(subgraph.node_pairs)
        defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([1]),
          tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)

        self._check_sampler_arguments(nodes, fanouts, probs_name)
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
            nodes,
            fanouts.tolist(),
            replace,
            True,
            has_original_eids,
            probs_name,
        )

        if not output_cscformat:
            return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
        else:
            return self._convert_to_sampled_subgraph(C_sampled_subgraph)

    def sample_negative_edges_uniform(
        self, edge_type, node_pairs, negative_ratio
    ):
        """
        Sample negative edges by randomly choosing negative source-destination
        pairs according to a uniform distribution. For each edge ``(u, v)``,
        it generates ``negative_ratio`` pairs of negative edges
        ``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in
        the graph.

        Parameters
        ----------
        edge_type: str
            The type of edges in the provided node_pairs. Any negative edges
            sampled will also have the same type. If set to None, the graph
            is treated as homogeneous.
        node_pairs : Tuple[Tensor, Tensor]
            A tuple of two 1D tensors that represent the source and destination
            of positive edges, with 'positive' indicating that these edges are
            present in the graph. It's important to note that within the
            context of a heterogeneous graph, the ids in these tensors signify
            heterogeneous ids.
        negative_ratio: int
            The ratio of the number of negative samples to positive samples.

        Returns
        -------
        Tuple[Tensor, Tensor]
            A tuple of two 1D tensors representing the source and destination
            of the negative edges. In the context of a heterogeneous graph,
            both the input nodes and the selected nodes are represented by
            heterogeneous IDs, and the formed edges are of the input type
            `edge_type`. Note that the sampled negatives may be false
            negatives; a sampled edge may actually exist in the graph.
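
        Examples
        --------
        A hedged sketch of the calling convention only; ``graph`` is assumed
        to be a homogeneous `FusedCSCSamplingGraph`, and the output is
        omitted because the sampled destinations are random:

        >>> pos_src = torch.LongTensor([0, 1])
        >>> pos_dst = torch.LongTensor([1, 0])
        >>> neg_src, neg_dst = graph.sample_negative_edges_uniform(
        ...     None, (pos_src, pos_dst), negative_ratio=2)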
        """
        if edge_type is not None:
            assert (
                self.node_type_offset is not None
            ), "The 'node_type_offset' array is necessary for performing \
                negative sampling by edge type."
            _, _, dst_node_type = etype_str_to_tuple(edge_type)
            dst_node_type_id = self.node_type_to_id[dst_node_type]
            max_node_id = (
                self.node_type_offset[dst_node_type_id + 1]
                - self.node_type_offset[dst_node_type_id]
            )
        else:
            max_node_id = self.total_num_nodes
        return self._c_csc_graph.sample_negative_edges_uniform(
            node_pairs,
            negative_ratio,
            max_node_id,
        )

    def copy_to_shared_memory(self, shared_memory_name: str):
        """Copy the graph to shared memory.

        Parameters
        ----------
        shared_memory_name : str
            Name of the shared memory.

        Returns
        -------
        FusedCSCSamplingGraph
            The copied FusedCSCSamplingGraph object on shared memory.
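
        Examples
        --------
        A hedged round-trip sketch; the name "graph_shm" is arbitrary and
        ``graph`` is assumed to be an existing `FusedCSCSamplingGraph`:

        >>> shm_graph = graph.copy_to_shared_memory("graph_shm")
        >>> same_graph = gb.load_from_shared_memory("graph_shm")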
        """
        return FusedCSCSamplingGraph(
            self._c_csc_graph.copy_to_shared_memory(shared_memory_name),
        )

    def to(  # pylint: disable=invalid-name
        self, device: torch.device
    ) -> "FusedCSCSamplingGraph":
        """Copy `FusedCSCSamplingGraph` to the specified device."""

        def _to(x, device):
            return x.to(device) if hasattr(x, "to") else x

        self.csc_indptr = recursive_apply(
            self.csc_indptr, lambda x: _to(x, device)
        )
        self.indices = recursive_apply(self.indices, lambda x: _to(x, device))
        self.node_type_offset = recursive_apply(
            self.node_type_offset, lambda x: _to(x, device)
        )
        self.type_per_edge = recursive_apply(
            self.type_per_edge, lambda x: _to(x, device)
        )
        self.edge_attributes = recursive_apply(
            self.edge_attributes, lambda x: _to(x, device)
        )

        return self


def fused_csc_sampling_graph(
    csc_indptr: torch.Tensor,
    indices: torch.Tensor,
    node_type_offset: Optional[torch.Tensor] = None,
    type_per_edge: Optional[torch.Tensor] = None,
    node_type_to_id: Optional[Dict[str, int]] = None,
    edge_type_to_id: Optional[Dict[str, int]] = None,
    edge_attributes: Optional[Dict[str, torch.Tensor]] = None,
) -> FusedCSCSamplingGraph:
    """Create a FusedCSCSamplingGraph object from a CSC representation.

    Parameters
    ----------
    csc_indptr : torch.Tensor
        Pointer to the start of each row in the `indices`. An integer tensor
        with shape `(total_num_nodes+1,)`.
    indices : torch.Tensor
        Column indices of the non-zero elements in the CSC graph. An integer
        tensor with shape `(total_num_edges,)`.
    node_type_offset : Optional[torch.Tensor], optional
        Offset of node types in the graph, by default None.
    type_per_edge : Optional[torch.Tensor], optional
        Type ids of each edge in the graph, by default None.
    node_type_to_id : Optional[Dict[str, int]], optional
        Map node types to ids, by default None.
    edge_type_to_id : Optional[Dict[str, int]], optional
        Map edge types to ids, by default None.
    edge_attributes: Optional[Dict[str, torch.Tensor]], optional
        Edge attributes of the graph, by default None.

    Returns
    -------
    FusedCSCSamplingGraph
        The created FusedCSCSamplingGraph object.

    Examples
    --------
    >>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
    >>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}
    >>> csc_indptr = torch.tensor([0, 2, 5, 7])
    >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
    >>> node_type_offset = torch.tensor([0, 1, 2, 3])
    >>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
    >>> graph = graphbolt.fused_csc_sampling_graph(csc_indptr, indices,
    ...             node_type_offset=node_type_offset,
    ...             type_per_edge=type_per_edge,
    ...             node_type_to_id=ntypes, edge_type_to_id=etypes,
    ...             edge_attributes=None,)
    >>> print(graph)
    FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
                     indices=tensor([1, 3, 0, 1, 2, 0, 3]),
                     total_num_nodes=3, total_num_edges=7)
    """
    if node_type_to_id is not None and edge_type_to_id is not None:
        node_types = list(node_type_to_id.keys())
        edge_types = list(edge_type_to_id.keys())
        node_type_ids = list(node_type_to_id.values())
        edge_type_ids = list(edge_type_to_id.values())

        # Validate node_type_to_id.
        assert all(
            isinstance(x, str) for x in node_types
        ), "Node type name should be string."
        assert all(
            isinstance(x, int) for x in node_type_ids
        ), "Node type id should be int."
        assert len(node_type_ids) == len(
            set(node_type_ids)
        ), "Multiple node types shoud not be mapped to a same id."
        # Validate edge_type_to_id.
        for edge_type in edge_types:
            src, edge, dst = etype_str_to_tuple(edge_type)
            assert isinstance(edge, str), "Edge type name should be string."
            assert (
                src in node_types
            ), f"Unrecognized node type {src} in edge type {edge_type}"
            assert (
                dst in node_types
            ), f"Unrecognized node type {dst} in edge type {edge_type}"
        assert all(
            isinstance(x, int) for x in edge_type_ids
        ), "Edge type id should be int."
        assert len(edge_type_ids) == len(
            set(edge_type_ids)
        ), "Multiple edge types shoud not be mapped to a same id."

        if node_type_offset is not None:
            assert len(node_type_to_id) + 1 == node_type_offset.size(
                0
            ), "node_type_offset length should be |ntypes| + 1."
    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.fused_csc_sampling_graph(
            csc_indptr,
            indices,
            node_type_offset,
            type_per_edge,
            node_type_to_id,
            edge_type_to_id,
            edge_attributes,
        ),
    )


def load_from_shared_memory(
    shared_memory_name: str,
) -> FusedCSCSamplingGraph:
    """Load a FusedCSCSamplingGraph object from shared memory.

    Parameters
    ----------
    shared_memory_name : str
        Name of the shared memory.

    Returns
    -------
    FusedCSCSamplingGraph
        The loaded FusedCSCSamplingGraph object on shared memory.
    """
    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.load_from_shared_memory(shared_memory_name),
    )


def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
    """Internal function for converting a csc sampling graph to string
    representation.
    """
    csc_indptr_str = str(graph.csc_indptr)
    indices_str = str(graph.indices)
    meta_str = f"num_nodes={graph.total_num_nodes}, num_edges={graph.num_edges}"
    if graph.node_type_offset is not None:
        meta_str += f", node_type_offset={graph.node_type_offset}"
    if graph.type_per_edge is not None:
        meta_str += f", type_per_edge={graph.type_per_edge}"
    if graph.node_type_to_id is not None:
        meta_str += f", node_type_to_id={graph.node_type_to_id}"
    if graph.edge_type_to_id is not None:
        meta_str += f", edge_type_to_id={graph.edge_type_to_id}"
    if graph.edge_attributes is not None:
        meta_str += f", edge_attributes={graph.edge_attributes}"

    prefix = f"{type(graph).__name__}("

    def _add_indent(_str, indent):
        lines = _str.split("\n")
        lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
        return "\n".join(lines)

    final_str = (
        "csc_indptr="
        + _add_indent(csc_indptr_str, len("csc_indptr="))
        + ",\n"
        + "indices="
        + _add_indent(indices_str, len("indices="))
        + ",\n"
        + meta_str
        + ")"
    )

    final_str = prefix + _add_indent(final_str, len(prefix))
    return final_str


def from_dglgraph(
    g: DGLGraph,
    is_homogeneous: bool = False,
    include_original_edge_id: bool = False,
) -> FusedCSCSamplingGraph:
    """Convert a DGLGraph to FusedCSCSamplingGraph."""

    homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)

    if is_homogeneous:
        node_type_to_id = None
        edge_type_to_id = None
    else:
        # Initialize metadata.
        node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
        edge_type_to_id = {
            etype_tuple_to_str(etype): g.get_etype_id(etype)
            for etype in g.canonical_etypes
        }

    # Obtain CSC matrix.
    indptr, indices, edge_ids = homo_g.adj_tensors("csc")
    ntype_count.insert(0, 0)
    node_type_offset = (
        None
        if is_homogeneous
        else torch.cumsum(torch.LongTensor(ntype_count), 0)
    )

    # Assign edge type according to the order of CSC matrix.
    type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]

    edge_attributes = {}
    if include_original_edge_id:
        # Assign edge attributes according to the original eids mapping.
        edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]

    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.fused_csc_sampling_graph(
            indptr,
            indices,
            node_type_offset,
            type_per_edge,
            node_type_to_id,
            edge_type_to_id,
            edge_attributes,
        ),
    )