"docs/vscode:/vscode.git/clone" did not exist on "709cf554f69cd40c310a9bdb52a8d85dfc64c274"
"""CSC format sampling graph."""
# pylint: disable= invalid-name
from collections import defaultdict
from typing import Dict, Optional, Union

import torch

from dgl.utils import recursive_apply

from ...base import EID, ETYPE
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import (
    CSCFormatBase,
    FusedSampledSubgraphImpl,
    SampledSubgraphImpl,
)


__all__ = [
    "FusedCSCSamplingGraph",
    "fused_csc_sampling_graph",
    "load_from_shared_memory",
    "from_dglgraph",
]


class FusedCSCSamplingGraph(SamplingGraph):
    r"""A sampling graph in CSC format."""

    def __repr__(self):
        return _csc_sampling_graph_str(self)

    def __init__(
        self,
        c_csc_graph: torch.ScriptObject,
    ):
        super().__init__()
        self._c_csc_graph = c_csc_graph

    @property
    def total_num_nodes(self) -> int:
        """Returns the number of nodes in the graph.

        Returns
        -------
        int
            The total number of nodes in the graph.
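
        Examples
        --------
        A minimal illustrative sketch on a tiny two-node homogeneous graph:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 1, 3])
        >>> indices = torch.LongTensor([1, 0, 1])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
        >>> graph.total_num_nodes
        2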
        """
        return self._c_csc_graph.num_nodes()

    @property
    def total_num_edges(self) -> int:
        """Returns the number of edges in the graph.

        Returns
        -------
        int
            The number of edges in the graph.
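
        Examples
        --------
        A minimal illustrative sketch on a tiny homogeneous graph with three
        edges:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 1, 3])
        >>> indices = torch.LongTensor([1, 0, 1])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
        >>> graph.total_num_edges
        3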
        """
        return self._c_csc_graph.num_edges()

    @property
    def num_nodes(self) -> Union[int, Dict[str, int]]:
        """The number of nodes in the graph.
        - If the graph is homogeneous, returns an integer.
        - If the graph is heterogeneous, returns a dictionary.

        Returns
        -------
        Union[int, Dict[str, int]]
            The number of nodes. An integer indicates the total number of
            nodes in a homogeneous graph; a dict maps each node type to its
            number of nodes in a heterogeneous graph.

        Examples
        --------
        >>> import dgl.graphbolt as gb, torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
        ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> print(graph.num_nodes)
        {'N0': 2, 'N1': 3}
        """

        offset = self.node_type_offset

        # Homogeneous.
        if offset is None or self.node_type_to_id is None:
            return self._c_csc_graph.num_nodes()

        # Heterogeneous.
        else:
            num_nodes_per_type = {
                _type: (offset[_idx + 1] - offset[_idx]).item()
                for _type, _idx in self.node_type_to_id.items()
            }

            return num_nodes_per_type

    @property
    def num_edges(self) -> Union[int, Dict[str, int]]:
        """The number of edges in the graph.
        - If the graph is homogeneous, returns an integer.
        - If the graph is heterogeneous, returns a dictionary.

        Returns
        -------
        Union[int, Dict[str, int]]
            The number of edges. An integer indicates the total number of
            edges in a homogeneous graph; a dict maps each edge type to its
            number of edges in a heterogeneous graph.

        Examples
        --------
        >>> import dgl.graphbolt as gb, torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
        ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> print(graph.num_edges)
        {'N0:R0:N0': 2, 'N0:R1:N1': 4, 'N1:R2:N0': 3, 'N1:R3:N1': 3}
        """

        type_per_edge = self.type_per_edge

        # Homogeneous.
        if type_per_edge is None or self.edge_type_to_id is None:
            return self._c_csc_graph.num_edges()

        # Heterogeneous.
        bincount = torch.bincount(type_per_edge)
        num_edges_per_type = {}
        for etype, etype_id in self.edge_type_to_id.items():
            if etype_id < len(bincount):
                num_edges_per_type[etype] = bincount[etype_id].item()
            else:
                num_edges_per_type[etype] = 0
        return num_edges_per_type

    @property
    def csc_indptr(self) -> torch.tensor:
        """Returns the indices pointer in the CSC graph.

        Returns
        -------
        torch.tensor
            The indices pointer in the CSC graph. An integer tensor with
            shape `(total_num_nodes+1,)`.
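
        Examples
        --------
        A minimal illustrative sketch; the property exposes the `csc_indptr`
        tensor the graph was constructed with:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 1, 3])
        >>> indices = torch.LongTensor([1, 0, 1])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
        >>> graph.csc_indptr
        tensor([0, 1, 3])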
        """
        return self._c_csc_graph.csc_indptr()

    @csc_indptr.setter
    def csc_indptr(self, csc_indptr: torch.tensor) -> None:
        """Sets the indices pointer in the CSC graph."""
        self._c_csc_graph.set_csc_indptr(csc_indptr)

    @property
    def indices(self) -> torch.tensor:
        """Returns the indices in the CSC graph.

        Returns
        -------
        torch.tensor
            The indices in the CSC graph. An integer tensor with shape
            `(total_num_edges,)`.

        Notes
        -----
        It is assumed that edges of each node are already sorted by edge type
        ids.
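
        Examples
        --------
        A minimal illustrative sketch; the property exposes the `indices`
        tensor the graph was constructed with:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 1, 3])
        >>> indices = torch.LongTensor([1, 0, 1])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
        >>> graph.indices
        tensor([1, 0, 1])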
        """
        return self._c_csc_graph.indices()

    @indices.setter
    def indices(self, indices: torch.tensor) -> None:
        """Sets the indices in the CSC graph."""
        self._c_csc_graph.set_indices(indices)

    @property
    def node_type_offset(self) -> Optional[torch.Tensor]:
        """Returns the node type offset tensor if present.

        Returns
        -------
        torch.Tensor or None
            If present, returns a 1D integer tensor of shape
            `(num_node_types + 1,)`. The tensor is in ascending order as nodes
            of the same type have continuous IDs, and larger node IDs are
            paired with larger node type IDs. The first value is 0 and the
            last value is the total number of nodes. Nodes with IDs between
            `node_type_offset[i]` and `node_type_offset[i+1]` are of type `i`.

        """
        return self._c_csc_graph.node_type_offset()

    @node_type_offset.setter
    def node_type_offset(
        self, node_type_offset: Optional[torch.Tensor]
    ) -> None:
        """Sets the node type offset tensor if present."""
        self._c_csc_graph.set_node_type_offset(node_type_offset)

    @property
    def type_per_edge(self) -> Optional[torch.Tensor]:
        """Returns the edge type tensor if present.

        Returns
        -------
        torch.Tensor or None
            If present, returns a 1D integer tensor of shape (total_num_edges,)
            containing the type of each edge in the graph.
        """
        return self._c_csc_graph.type_per_edge()

    @type_per_edge.setter
    def type_per_edge(self, type_per_edge: Optional[torch.Tensor]) -> None:
        """Sets the edge type tensor if present."""
        self._c_csc_graph.set_type_per_edge(type_per_edge)

    @property
    def node_type_to_id(self) -> Optional[Dict[str, int]]:
        """Returns the node type to id dictionary if present.

        Returns
        -------
        Dict[str, int] or None
            If present, returns a dictionary mapping node type to node type
            id.
        """
        return self._c_csc_graph.node_type_to_id()

    @node_type_to_id.setter
    def node_type_to_id(
        self, node_type_to_id: Optional[Dict[str, int]]
    ) -> None:
        """Sets the node type to id dictionary if present."""
        self._c_csc_graph.set_node_type_to_id(node_type_to_id)

    @property
    def edge_type_to_id(self) -> Optional[Dict[str, int]]:
        """Returns the edge type to id dictionary if present.

        Returns
        -------
        Dict[str, int] or None
            If present, returns a dictionary mapping edge type to edge type
            id.
        """
        return self._c_csc_graph.edge_type_to_id()

    @edge_type_to_id.setter
    def edge_type_to_id(
        self, edge_type_to_id: Optional[Dict[str, int]]
    ) -> None:
        """Sets the edge type to id dictionary if present."""
        self._c_csc_graph.set_edge_type_to_id(edge_type_to_id)

    @property
    def node_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
        """Returns the node attributes dictionary.

        Returns
        -------
        Dict[str, torch.Tensor] or None
            If present, returns a dictionary of node attributes. Each key
            represents the attribute's name, while the corresponding value
            holds the attribute's specific value. The length of each value
            should match the total number of nodes.
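
        Examples
        --------
        An illustrative sketch that attaches a single node feature tensor
        under the (arbitrary) name "feat" at construction time:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 1, 3])
        >>> indices = torch.LongTensor([1, 0, 1])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_attributes={"feat": torch.zeros(2, 4)})
        >>> "feat" in graph.node_attributes
        True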
        """
        return self._c_csc_graph.node_attributes()

    @node_attributes.setter
    def node_attributes(
        self, node_attributes: Optional[Dict[str, torch.Tensor]]
    ) -> None:
        """Sets the node attributes dictionary."""
        self._c_csc_graph.set_node_attributes(node_attributes)

    @property
    def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
        """Returns the edge attributes dictionary.

        Returns
        -------
        Dict[str, torch.Tensor] or None
            If present, returns a dictionary of edge attributes. Each key
            represents the attribute's name, while the corresponding value
            holds the attribute's specific value. The length of each value
            should match the total number of edges.
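
        Examples
        --------
        An illustrative sketch that attaches per-edge (unnormalized) weights
        under the (arbitrary) name "prob"; such an attribute can later be
        referenced through `probs_name` in the sampling methods:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 1, 3])
        >>> indices = torch.LongTensor([1, 0, 1])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     edge_attributes={"prob": torch.FloatTensor([0.5, 1.0, 2.0])})
        >>> "prob" in graph.edge_attributes
        True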
        """
        return self._c_csc_graph.edge_attributes()

    @edge_attributes.setter
    def edge_attributes(
        self, edge_attributes: Optional[Dict[str, torch.Tensor]]
    ) -> None:
        """Sets the edge attributes dictionary."""
        self._c_csc_graph.set_edge_attributes(edge_attributes)

    def in_subgraph(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        # TODO: clean up once the migration is done.
        output_cscformat=False,
    ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
        """Return the subgraph induced on the inbound edges of the given nodes.

        An in subgraph is equivalent to creating a new graph using the incoming
        edges of the given nodes. The subgraph is compacted according to the
        order of the passed-in `nodes`.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: the graph is treated as a homogeneous
                graph, and the ids inside are homogeneous ids.
              - If `nodes` is a dictionary: the keys should be node types, and
                the ids inside are heterogeneous ids.

        Returns
        -------
        FusedSampledSubgraphImpl
            The in subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {
        ...     "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
        >>> in_subgraph = graph.in_subgraph(nodes)
        >>> print(in_subgraph.node_pairs)
        defaultdict(<class 'list'>, {
            'N0:R0:N0': (tensor([]), tensor([])),
            'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])),
            'N1:R2:N0': (tensor([0, 1]), tensor([1, 1])),
            'N1:R3:N1': (tensor([0, 1, 2]), tensor([1, 2, 2]))})
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)
        # Ensure nodes is 1-D tensor.
        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
        # Ensure that there are no duplicate nodes.
        assert len(torch.unique(nodes)) == len(
            nodes
        ), "Nodes cannot have duplicate values."

        _in_subgraph = self._c_csc_graph.in_subgraph(nodes)
        if not output_cscformat:
            return self._convert_to_fused_sampled_subgraph(_in_subgraph)
        else:
            return self._convert_to_sampled_subgraph(_in_subgraph)

    def _convert_to_fused_sampled_subgraph(
        self,
        C_sampled_subgraph: torch.ScriptObject,
    ):
        """An internal function used to convert a fused homogeneous sampled
        subgraph to the general struct 'FusedSampledSubgraphImpl'."""
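        # Expand the compressed column pointer into one destination (column)
        # id per sampled edge: each seed node id is repeated once per in-edge.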
        column_num = (
            C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
        )
        column = C_sampled_subgraph.original_column_node_ids.repeat_interleave(
            column_num
        )
        row = C_sampled_subgraph.indices
        type_per_edge = C_sampled_subgraph.type_per_edge
        original_edge_ids = C_sampled_subgraph.original_edge_ids
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        if has_original_eids:
            original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
                original_edge_ids
            ]
        if type_per_edge is None:
            # The sampled graph is already a homogeneous graph.
            node_pairs = (row, column)
        else:
            # The sampled graph is a fused homogenized graph, which needs to
            # be converted to a heterogeneous graph.
            node_pairs = defaultdict(list)
            original_hetero_edge_ids = {}
            for etype, etype_id in self.edge_type_to_id.items():
                src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
                src_ntype_id = self.node_type_to_id[src_ntype]
                dst_ntype_id = self.node_type_to_id[dst_ntype]
                mask = type_per_edge == etype_id
                hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
                hetero_column = (
                    column[mask] - self.node_type_offset[dst_ntype_id]
                )
                node_pairs[etype] = (hetero_row, hetero_column)
                if has_original_eids:
                    original_hetero_edge_ids[etype] = original_edge_ids[mask]
            if has_original_eids:
                original_edge_ids = original_hetero_edge_ids
        return FusedSampledSubgraphImpl(
            node_pairs=node_pairs, original_edge_ids=original_edge_ids
        )

    def _convert_to_homogeneous_nodes(self, nodes):
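        """Internal helper that converts per-type (heterogeneous) node ids
        into homogeneous ids by adding each node type's offset."""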
        homogeneous_nodes = []
        for ntype, ids in nodes.items():
            ntype_id = self.node_type_to_id[ntype]
            homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
        return torch.cat(homogeneous_nodes)

    def _convert_to_sampled_subgraph(
        self,
        C_sampled_subgraph: torch.ScriptObject,
    ) -> SampledSubgraphImpl:
        """An internal function used to convert a fused homogeneous sampled
        subgraph to the general struct 'SampledSubgraphImpl'."""
        indptr = C_sampled_subgraph.indptr
        indices = C_sampled_subgraph.indices
        type_per_edge = C_sampled_subgraph.type_per_edge
        column = C_sampled_subgraph.original_column_node_ids
        original_edge_ids = C_sampled_subgraph.original_edge_ids

        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        if has_original_eids:
            original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
                original_edge_ids
            ]
        if type_per_edge is None:
            # The sampled graph is already a homogeneous graph.
            node_pairs = CSCFormatBase(indptr=indptr, indices=indices)
        else:
            # The sampled graph is a fused homogenized graph, which needs to
            # be converted to a heterogeneous graph.
            # Pre-calculate the number of edges for each edge type.
            num = {}
            for etype in type_per_edge:
                num[etype.item()] = num.get(etype.item(), 0) + 1
            # Preallocate
            subgraph_indice_position = {}
            subgraph_indice = {}
            subgraph_indptr = {}
            node_edge_type = defaultdict(list)
            original_hetero_edge_ids = {}
            for etype, etype_id in self.edge_type_to_id.items():
                subgraph_indice[etype] = torch.empty(
                    (num.get(etype_id, 0),), dtype=indices.dtype
                )
                if has_original_eids:
                    original_hetero_edge_ids[etype] = torch.empty(
                        (num.get(etype_id, 0),), dtype=original_edge_ids.dtype
                    )
                subgraph_indptr[etype] = [0]
                subgraph_indice_position[etype] = 0
                # Preprocessing saves the type of seed_nodes as the edge type
                # of dst_ntype.
                _, _, dst_ntype = etype_str_to_tuple(etype)
                dst_ntype_id = self.node_type_to_id[dst_ntype]
                node_edge_type[dst_ntype_id].append((etype, etype_id))
            # construct subgraphs
            for i, seed in enumerate(column):
                l = indptr[i].item()
                r = indptr[i + 1].item()
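                # Locate the seed's node type by binary-searching the node
                # type offsets (types partition the homogeneous id range).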
                node_type = (
                    torch.searchsorted(
                        self.node_type_offset, seed, right=True
                    ).item()
                    - 1
                )
                for etype, etype_id in node_edge_type[node_type]:
                    src_ntype, _, _ = etype_str_to_tuple(etype)
                    src_ntype_id = self.node_type_to_id[src_ntype]
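                    # Edges of each node are sorted by edge type id (see the
                    # notes of `indices`), so a binary search over
                    # type_per_edge[l:r] yields the number of edges of this
                    # edge type for the current seed.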
                    num_edges = torch.searchsorted(
                        type_per_edge[l:r], etype_id, right=True
                    ).item()
                    end = num_edges + l
                    subgraph_indptr[etype].append(
                        subgraph_indptr[etype][-1] + num_edges
                    )
                    offset = subgraph_indice_position[etype]
                    subgraph_indice_position[etype] += num_edges
                    subgraph_indice[etype][offset : offset + num_edges] = (
                        indices[l:end] - self.node_type_offset[src_ntype_id]
                    )
                    if has_original_eids:
                        original_hetero_edge_ids[etype][
                            offset : offset + num_edges
                        ] = original_edge_ids[l:end]
                    l = end
            if has_original_eids:
                original_edge_ids = original_hetero_edge_ids
            node_pairs = {
                etype: CSCFormatBase(
                    indptr=torch.tensor(subgraph_indptr[etype]),
                    indices=subgraph_indice[etype],
                )
                for etype in self.edge_type_to_id.keys()
            }
        return SampledSubgraphImpl(
            node_pairs=node_pairs,
            original_edge_ids=original_edge_ids,
        )

    def sample_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
        # TODO: clean up once the migration is done.
        output_cscformat=False,
    ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: the graph is treated as a homogeneous
                graph, and the ids inside are homogeneous ids.
              - If `nodes` is a dictionary: the keys should be node types, and
                the ids inside are heterogeneous ids.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal the number of edge types,
                and each fanout value corresponds to a specific edge type of
                the nodes.
            The value of each fanout should be >= 0 or equal to -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute.
            This attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        FusedSampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_neighbors(nodes, fanouts)
        >>> print(subgraph.node_pairs)
        defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([0]),
          tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)

        C_sampled_subgraph = self._sample_neighbors(
            nodes, fanouts, replace, probs_name
        )
        if not output_cscformat:
            return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
        else:
            return self._convert_to_sampled_subgraph(C_sampled_subgraph)

    def _check_sampler_arguments(self, nodes, fanouts, probs_name):
        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
        assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
        expected_fanout_len = 1
        if self.edge_type_to_id:
            expected_fanout_len = len(self.edge_type_to_id)
        assert len(fanouts) in [
            expected_fanout_len,
            1,
        ], "Fanouts should have the same number of elements as etypes or \
            should have a length of 1."
        if fanouts.size(0) > 1:
            assert (
                self.type_per_edge is not None
            ), "To perform sampling for each edge type (when the length of \
                `fanouts` > 1), the graph must include edge type information."
        assert torch.all(
            (fanouts >= 0) | (fanouts == -1)
        ), "Fanouts should consist of values that are either -1 or \
            greater than or equal to 0."
        if probs_name:
            assert (
                probs_name in self.edge_attributes
            ), f"Unknown edge attribute '{probs_name}'."
            probs_or_mask = self.edge_attributes[probs_name]
            assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
            assert (
                probs_or_mask.size(0) == self.total_num_edges
            ), "Probs should have the same number of elements as the number \
                of edges."
            assert probs_or_mask.dtype in [
                torch.bool,
                torch.float16,
                torch.bfloat16,
                torch.float32,
                torch.float64,
            ], "Probs should have a floating-point or boolean data type."

    def _sample_neighbors(
        self,
        nodes: torch.Tensor,
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
    ) -> torch.ScriptObject:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph.

        Parameters
        ----------
        nodes: torch.Tensor
            IDs of the given seed nodes.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or equal to -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        torch.classes.graphbolt.SampledSubgraph
            The sampled C subgraph.
        """
        # Validate the sampler arguments (nodes must be a 1-D tensor, etc.).
        self._check_sampler_arguments(nodes, fanouts, probs_name)
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        return self._c_csc_graph.sample_neighbors(
            nodes,
            fanouts.tolist(),
            replace,
            False,
            has_original_eids,
            probs_name,
        )

    def sample_layer_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
        # TODO: clean up once the migration is done.
        output_cscformat=False,
    ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
        `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
        <https://arxiv.org/abs/2210.13339>`__

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: the graph is treated as a homogeneous
                graph, and the ids inside are homogeneous ids.
              - If `nodes` is a dictionary: the keys should be node types, and
                the ids inside are heterogeneous ids.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal the number of edge types,
                and each fanout value corresponds to a specific edge type of
                the nodes.
            The value of each fanout should be >= 0 or equal to -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        FusedSampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
        >>> print(subgraph.node_pairs)
        defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([1]),
          tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)

        self._check_sampler_arguments(nodes, fanouts, probs_name)
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
            nodes,
            fanouts.tolist(),
            replace,
            True,
            has_original_eids,
            probs_name,
        )

        if not output_cscformat:
            return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
        else:
            return self._convert_to_sampled_subgraph(C_sampled_subgraph)

    def sample_negative_edges_uniform(
        self, edge_type, node_pairs, negative_ratio
    ):
        """
        Sample negative edges by randomly choosing negative source-destination
        pairs according to a uniform distribution. For each edge ``(u, v)``,
        it generates ``negative_ratio`` pairs of negative edges
        ``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in
        the graph.

        Parameters
        ----------
        edge_type: str
            The type of edges in the provided node_pairs. Any negative edges
            sampled will also have the same type. If set to None, it will be
            considered as a homogeneous graph.
        node_pairs : Tuple[Tensor, Tensor]
            A tuple of two 1D tensors that represent the source and destination
            of positive edges, with 'positive' indicating that these edges are
            present in the graph. It's important to note that within the
            context of a heterogeneous graph, the ids in these tensors signify
            heterogeneous ids.
        negative_ratio: int
            The ratio of the number of negative samples to positive samples.

        Returns
        -------
        Tuple[Tensor, Tensor]
            A tuple consisting of two 1D tensors represents the source and
            destination of negative edges. In the context of a heterogeneous
            graph, both the input nodes and the selected nodes are represented
            by heterogeneous IDs, and the formed edges are of the input type
            `edge_type`. Note that "negative" may include false negatives: a
            sampled pair may or may not actually be present in the graph.
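
        Examples
        --------
        An illustrative sketch on a small homogeneous graph (`edge_type` is
        None); the sampled destinations are random, so no output is shown:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 2, 4, 5])
        >>> indices = torch.LongTensor([1, 2, 0, 2, 0])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
        >>> pos_src = torch.LongTensor([0, 1])
        >>> pos_dst = torch.LongTensor([1, 2])
        >>> neg_src, neg_dst = graph.sample_negative_edges_uniform(
        ...     None, (pos_src, pos_dst), negative_ratio=2)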
        """
        if edge_type is not None:
            assert (
                self.node_type_offset is not None
            ), "The 'node_type_offset' array is necessary for performing \
                negative sampling by edge type."
            _, _, dst_node_type = etype_str_to_tuple(edge_type)
            dst_node_type_id = self.node_type_to_id[dst_node_type]
            max_node_id = (
                self.node_type_offset[dst_node_type_id + 1]
                - self.node_type_offset[dst_node_type_id]
            )
        else:
            max_node_id = self.total_num_nodes
        return self._c_csc_graph.sample_negative_edges_uniform(
            node_pairs,
            negative_ratio,
            max_node_id,
        )

    def copy_to_shared_memory(self, shared_memory_name: str):
        """Copy the graph to shared memory.

        Parameters
        ----------
        shared_memory_name : str
            Name of the shared memory.

        Returns
        -------
        FusedCSCSamplingGraph
            The copied FusedCSCSamplingGraph object on shared memory.
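
        Examples
        --------
        A minimal illustrative sketch; "graph_shm" is an arbitrary name chosen
        here for the shared-memory segment:

        >>> import dgl.graphbolt as gb, torch
        >>> indptr = torch.LongTensor([0, 1, 3])
        >>> indices = torch.LongTensor([1, 0, 1])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
        >>> shm_graph = graph.copy_to_shared_memory("graph_shm")
        >>> shm_graph.total_num_nodes
        2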
        """
        return FusedCSCSamplingGraph(
            self._c_csc_graph.copy_to_shared_memory(shared_memory_name),
        )

    def to(  # pylint: disable=invalid-name
        self, device: torch.device
    ) -> "FusedCSCSamplingGraph":
        """Copy `FusedCSCSamplingGraph` to the specified device."""

        def _to(x, device):
            return x.to(device) if hasattr(x, "to") else x

        self.csc_indptr = recursive_apply(
            self.csc_indptr, lambda x: _to(x, device)
        )
        self.indices = recursive_apply(self.indices, lambda x: _to(x, device))
        self.node_type_offset = recursive_apply(
            self.node_type_offset, lambda x: _to(x, device)
        )
        self.type_per_edge = recursive_apply(
            self.type_per_edge, lambda x: _to(x, device)
        )
        self.node_attributes = recursive_apply(
            self.node_attributes, lambda x: _to(x, device)
        )
        self.edge_attributes = recursive_apply(
            self.edge_attributes, lambda x: _to(x, device)
        )

        return self


def fused_csc_sampling_graph(
    csc_indptr: torch.Tensor,
    indices: torch.Tensor,
    node_type_offset: Optional[torch.tensor] = None,
    type_per_edge: Optional[torch.tensor] = None,
    node_type_to_id: Optional[Dict[str, int]] = None,
    edge_type_to_id: Optional[Dict[str, int]] = None,
    node_attributes: Optional[Dict[str, torch.tensor]] = None,
    edge_attributes: Optional[Dict[str, torch.tensor]] = None,
) -> FusedCSCSamplingGraph:
    """Create a FusedCSCSamplingGraph object from a CSC representation.

    Parameters
    ----------
    csc_indptr : torch.Tensor
        Pointer to the start of each row in the `indices`. An integer tensor
        with shape `(total_num_nodes+1,)`.
    indices : torch.Tensor
        Column indices of the non-zero elements in the CSC graph. An integer
        tensor with shape `(total_num_edges,)`.
    node_type_offset : Optional[torch.tensor], optional
        Offset of node types in the graph, by default None.
    type_per_edge : Optional[torch.tensor], optional
        Type ids of each edge in the graph, by default None.
    node_type_to_id : Optional[Dict[str, int]], optional
        Map node types to ids, by default None.
    edge_type_to_id : Optional[Dict[str, int]], optional
        Map edge types to ids, by default None.
    node_attributes: Optional[Dict[str, torch.tensor]], optional
        Node attributes of the graph, by default None.
    edge_attributes: Optional[Dict[str, torch.tensor]], optional
        Edge attributes of the graph, by default None.

    Returns
    -------
    FusedCSCSamplingGraph
        The created FusedCSCSamplingGraph object.

    Examples
    --------
    >>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
    >>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}
    >>> csc_indptr = torch.tensor([0, 2, 5, 7])
    >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
    >>> node_type_offset = torch.tensor([0, 1, 2, 3])
    >>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
    >>> graph = graphbolt.fused_csc_sampling_graph(csc_indptr, indices,
    ...             node_type_offset=node_type_offset,
    ...             type_per_edge=type_per_edge,
    ...             node_type_to_id=ntypes, edge_type_to_id=etypes,
    ...             node_attributes=None, edge_attributes=None,)
    >>> print(graph)
    FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
                     indices=tensor([1, 3, 0, 1, 2, 0, 3]),
                     total_num_nodes=3, total_num_edges=7)
    """
    if node_type_to_id is not None and edge_type_to_id is not None:
        node_types = list(node_type_to_id.keys())
        edge_types = list(edge_type_to_id.keys())
        node_type_ids = list(node_type_to_id.values())
        edge_type_ids = list(edge_type_to_id.values())

        # Validate node_type_to_id.
        assert all(
            isinstance(x, str) for x in node_types
        ), "Node type name should be string."
        assert all(
            isinstance(x, int) for x in node_type_ids
        ), "Node type id should be int."
        assert len(node_type_ids) == len(
            set(node_type_ids)
        ), "Multiple node types should not be mapped to the same id."
        # Validate edge_type_to_id.
        for edge_type in edge_types:
            src, edge, dst = etype_str_to_tuple(edge_type)
            assert isinstance(edge, str), "Edge type name should be string."
            assert (
                src in node_types
            ), f"Unrecognized node type {src} in edge type {edge_type}"
            assert (
                dst in node_types
            ), f"Unrecognized node type {dst} in edge type {edge_type}"
        assert all(
            isinstance(x, int) for x in edge_type_ids
        ), "Edge type id should be int."
        assert len(edge_type_ids) == len(
            set(edge_type_ids)
        ), "Multiple edge types should not be mapped to the same id."

        if node_type_offset is not None:
            assert len(node_type_to_id) + 1 == node_type_offset.size(
                0
            ), "node_type_offset length should be |ntypes| + 1."
    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.fused_csc_sampling_graph(
            csc_indptr,
            indices,
            node_type_offset,
            type_per_edge,
            node_type_to_id,
            edge_type_to_id,
            node_attributes,
            edge_attributes,
        ),
    )


def load_from_shared_memory(
    shared_memory_name: str,
) -> FusedCSCSamplingGraph:
    """Load a FusedCSCSamplingGraph object from shared memory.

    Parameters
    ----------
    shared_memory_name : str
        Name of the shared memory.

    Returns
    -------
    FusedCSCSamplingGraph
        The loaded FusedCSCSamplingGraph object on shared memory.
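
    Examples
    --------
    A minimal illustrative sketch, assuming a graph was previously copied to
    shared memory under the arbitrary name "graph_shm":

    >>> import dgl.graphbolt as gb, torch
    >>> indptr = torch.LongTensor([0, 1, 3])
    >>> indices = torch.LongTensor([1, 0, 1])
    >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
    >>> _ = graph.copy_to_shared_memory("graph_shm")
    >>> shm_graph = gb.load_from_shared_memory("graph_shm")
    >>> shm_graph.total_num_edges
    3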
    """
    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.load_from_shared_memory(shared_memory_name),
    )


def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
    """Internal function for converting a csc sampling graph to string
    representation.
    """
    csc_indptr_str = str(graph.csc_indptr)
    indices_str = str(graph.indices)
    meta_str = f"num_nodes={graph.total_num_nodes}, num_edges={graph.num_edges}"
    if graph.node_type_offset is not None:
        meta_str += f", node_type_offset={graph.node_type_offset}"
    if graph.type_per_edge is not None:
        meta_str += f", type_per_edge={graph.type_per_edge}"
    if graph.node_type_to_id is not None:
        meta_str += f", node_type_to_id={graph.node_type_to_id}"
    if graph.edge_type_to_id is not None:
        meta_str += f", edge_type_to_id={graph.edge_type_to_id}"
    if graph.node_attributes is not None:
        meta_str += f", node_attributes={graph.node_attributes}"
    if graph.edge_attributes is not None:
        meta_str += f", edge_attributes={graph.edge_attributes}"

    prefix = f"{type(graph).__name__}("

    def _add_indent(_str, indent):
        lines = _str.split("\n")
        lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
        return "\n".join(lines)

    final_str = (
        "csc_indptr="
        + _add_indent(csc_indptr_str, len("csc_indptr="))
        + ",\n"
        + "indices="
        + _add_indent(indices_str, len("indices="))
        + ",\n"
        + meta_str
        + ")"
    )

    final_str = prefix + _add_indent(final_str, len(prefix))
    return final_str


def from_dglgraph(
    g: DGLGraph,
    is_homogeneous: bool = False,
    include_original_edge_id: bool = False,
) -> FusedCSCSamplingGraph:
    """Convert a DGLGraph to FusedCSCSamplingGraph."""

    homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)

    if is_homogeneous:
        node_type_to_id = None
        edge_type_to_id = None
    else:
        # Initialize metadata.
        node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
        edge_type_to_id = {
            etype_tuple_to_str(etype): g.get_etype_id(etype)
            for etype in g.canonical_etypes
        }

    # Obtain CSC matrix.
    indptr, indices, edge_ids = homo_g.adj_tensors("csc")
    ntype_count.insert(0, 0)
    node_type_offset = (
        None
        if is_homogeneous
        else torch.cumsum(torch.LongTensor(ntype_count), 0)
    )

    # Assign edge type according to the order of CSC matrix.
    type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]

    node_attributes = {}

    edge_attributes = {}
    if include_original_edge_id:
        # Assign edge attributes according to the original eids mapping.
        edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]

    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.fused_csc_sampling_graph(
            indptr,
            indices,
            node_type_offset,
            type_per_edge,
            node_type_to_id,
            edge_type_to_id,
            node_attributes,
            edge_attributes,
        ),
    )