"vscode:/vscode.git/clone" did not exist on "edc154da0906ddc7ade2dfea739917266908451a"
fused_csc_sampling_graph.py 42.9 KB
Newer Older
1
2
"""CSC format sampling graph."""
# pylint: disable= invalid-name
from collections import defaultdict
from typing import Dict, Optional, Union

import torch

from dgl.utils import recursive_apply

from ...base import EID, ETYPE
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import (
    CSCFormatBase,
    FusedSampledSubgraphImpl,
    SampledSubgraphImpl,
)

__all__ = [
    "FusedCSCSamplingGraph",
    "fused_csc_sampling_graph",
    "load_from_shared_memory",
    "from_dglgraph",
]


class FusedCSCSamplingGraph(SamplingGraph):
    r"""A sampling graph in CSC format."""

    def __repr__(self):
        return _csc_sampling_graph_str(self)

    def __init__(
        self,
        c_csc_graph: torch.ScriptObject,
    ):
        super().__init__()
        self._c_csc_graph = c_csc_graph

    @property
    def total_num_nodes(self) -> int:
        """Returns the number of nodes in the graph.

        Returns
        -------
        int
            The number of nodes in the graph.
        """
        return self._c_csc_graph.num_nodes()

    @property
    def total_num_edges(self) -> int:
        """Returns the number of edges in the graph.

        Returns
        -------
        int
            The number of edges in the graph.
        """
        return self._c_csc_graph.num_edges()

    @property
    def num_nodes(self) -> Union[int, Dict[str, int]]:
        """The number of nodes in the graph.
        - If the graph is homogeneous, returns an integer.
        - If the graph is heterogeneous, returns a dictionary.

        Returns
        -------
        Union[int, Dict[str, int]]
            The number of nodes. An integer denotes the total number of nodes
            in a homogeneous graph; a dict maps each node type to its number
            of nodes in a heterogeneous graph.

        Examples
        --------
        >>> import dgl.graphbolt as gb, torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
        ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> print(graph.num_nodes)
        {'N0': 2, 'N1': 3}
        """

        offset = self.node_type_offset

        # Homogeneous.
        if offset is None or self.node_type_to_id is None:
            return self._c_csc_graph.num_nodes()

        # Heterogeneous.
        else:
            num_nodes_per_type = {
                _type: (offset[_idx + 1] - offset[_idx]).item()
                for _type, _idx in self.node_type_to_id.items()
            }

            return num_nodes_per_type

    @property
    def num_edges(self) -> Union[int, Dict[str, int]]:
        """The number of edges in the graph.
        - If the graph is homogeneous, returns an integer.
        - If the graph is heterogeneous, returns a dictionary.

        Returns
        -------
        Union[int, Dict[str, int]]
            The number of edges. An integer denotes the total number of edges
            in a homogeneous graph; a dict maps each edge type to its number
            of edges in a heterogeneous graph.

        Examples
        --------
        >>> import dgl.graphbolt as gb, torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
        ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> print(graph.num_edges)
        {'N0:R0:N0': 2, 'N0:R1:N1': 4, 'N1:R2:N0': 3, 'N1:R3:N1': 3}
        """

        type_per_edge = self.type_per_edge

        # Homogeneous.
        if type_per_edge is None or self.edge_type_to_id is None:
            return self._c_csc_graph.num_edges()

        # Heterogeneous.
        bincount = torch.bincount(type_per_edge)
        num_edges_per_type = {}
        for etype, etype_id in self.edge_type_to_id.items():
            if etype_id < len(bincount):
                num_edges_per_type[etype] = bincount[etype_id].item()
            else:
                num_edges_per_type[etype] = 0
        return num_edges_per_type

    @property
    def csc_indptr(self) -> torch.tensor:
        """Returns the indices pointer in the CSC graph.

        Returns
        -------
        torch.tensor
            The indices pointer in the CSC graph. An integer tensor with
            shape `(total_num_nodes+1,)`.
        """
        return self._c_csc_graph.csc_indptr()

    @csc_indptr.setter
    def csc_indptr(self, csc_indptr: torch.tensor) -> None:
        """Sets the indices pointer in the CSC graph."""
        self._c_csc_graph.set_csc_indptr(csc_indptr)
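    # A minimal sketch of how the CSC arrays are read (assumes the example
    # graph used throughout the docstrings below; not part of the class API):
    #
    #   indptr = graph.csc_indptr    # e.g. tensor([0, 3, 5, 7, 9, 12])
    #   indices = graph.indices      # e.g. tensor([0, 1, 4, 2, 3, ...])
    #   # In-edge sources of node ``v`` are ``indices[indptr[v]:indptr[v + 1]]``.
    #   in_neighbors_of_0 = indices[indptr[0]:indptr[1]]   # tensor([0, 1, 4])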

    @property
    def indices(self) -> torch.tensor:
        """Returns the indices in the CSC graph.

        Returns
        -------
        torch.tensor
            The indices in the CSC graph. An integer tensor with shape
            `(total_num_edges,)`.

        Notes
        -------
        It is assumed that edges of each node are already sorted by edge type
        ids.
        """
        return self._c_csc_graph.indices()

    @indices.setter
    def indices(self, indices: torch.tensor) -> None:
        """Sets the indices in the CSC graph."""
        self._c_csc_graph.set_indices(indices)

    @property
    def node_type_offset(self) -> Optional[torch.Tensor]:
        """Returns the node type offset tensor if present.

        Returns
        -------
        torch.Tensor or None
            If present, returns a 1D integer tensor of shape
            `(num_node_types + 1,)`. The tensor is in ascending order as nodes
            of the same type have continuous IDs, and larger node IDs are
            paired with larger node type IDs. The first value is 0 and last
            value is the number of nodes. Nodes with IDs between
            `node_type_offset[i]` and `node_type_offset[i+1]` are of type id
            `i`.

        """
        return self._c_csc_graph.node_type_offset()

    @node_type_offset.setter
    def node_type_offset(
        self, node_type_offset: Optional[torch.Tensor]
    ) -> None:
        """Sets the node type offset tensor if present."""
        self._c_csc_graph.set_node_type_offset(node_type_offset)
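    # A minimal sketch of how ``node_type_offset`` is interpreted (assumes the
    # heterogeneous example graph from the docstrings above):
    #
    #   offset = graph.node_type_offset            # e.g. tensor([0, 2, 5])
    #   # Nodes of type id ``i`` are the homogeneous ids
    #   # ``offset[i], ..., offset[i + 1] - 1``; their count is the difference.
    #   num_n0 = (offset[1] - offset[0]).item()    # 2 nodes of the first type
    #   num_n1 = (offset[2] - offset[1]).item()    # 3 nodes of the second type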

    @property
    def type_per_edge(self) -> Optional[torch.Tensor]:
        """Returns the edge type tensor if present.

        Returns
        -------
        torch.Tensor or None
            If present, returns a 1D integer tensor of shape (total_num_edges,)
            containing the type of each edge in the graph.
        """
        return self._c_csc_graph.type_per_edge()

    @type_per_edge.setter
    def type_per_edge(self, type_per_edge: Optional[torch.Tensor]) -> None:
        """Sets the edge type tensor if present."""
        self._c_csc_graph.set_type_per_edge(type_per_edge)

    @property
    def node_type_to_id(self) -> Optional[Dict[str, int]]:
        """Returns the node type to id dictionary if present.

        Returns
        -------
        Dict[str, int] or None
            If present, returns a dictionary mapping node type to node type
            id.
        """
        return self._c_csc_graph.node_type_to_id()

    @node_type_to_id.setter
    def node_type_to_id(
        self, node_type_to_id: Optional[Dict[str, int]]
    ) -> None:
        """Sets the node type to id dictionary if present."""
        self._c_csc_graph.set_node_type_to_id(node_type_to_id)

    @property
    def edge_type_to_id(self) -> Optional[Dict[str, int]]:
        """Returns the edge type to id dictionary if present.

        Returns
        -------
        Dict[str, int] or None
            If present, returns a dictionary mapping edge type to edge type
            id.
        """
        return self._c_csc_graph.edge_type_to_id()

    @edge_type_to_id.setter
    def edge_type_to_id(
        self, edge_type_to_id: Optional[Dict[str, int]]
    ) -> None:
        """Sets the edge type to id dictionary if present."""
        self._c_csc_graph.set_edge_type_to_id(edge_type_to_id)

    @property
    def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
        """Returns the edge attributes dictionary.

        Returns
        -------
        Dict[str, torch.Tensor] or None
            If present, returns a dictionary of edge attributes. Each key
            represents the attribute's name, while the corresponding value
            holds the attribute's specific value. The length of each value
            should match the total number of edges."
        """
        return self._c_csc_graph.edge_attributes()

    @edge_attributes.setter
    def edge_attributes(
        self, edge_attributes: Optional[Dict[str, torch.Tensor]]
    ) -> None:
        """Sets the edge attributes dictionary."""
        self._c_csc_graph.set_edge_attributes(edge_attributes)
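    # A minimal sketch of how edge attributes interact with weighted sampling
    # (the attribute name "prob" is a hypothetical choice, not a required key):
    #
    #   graph.edge_attributes = {
    #       "prob": torch.rand(graph.total_num_edges),  # one weight per edge
    #   }
    #   # Later, pass the same name so the sampler uses these weights:
    #   #   graph.sample_neighbors(nodes, fanouts, probs_name="prob")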

    def in_subgraph(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        # TODO: clean up once the migration is done.
        output_cscformat=False,
    ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
        """Return the subgraph induced on the inbound edges of the given nodes.

        An in subgraph is equivalent to creating a new graph using the incoming
        edges of the given nodes. The subgraph is compacted according to the
        order of the passed-in `nodes`.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: It means the graph is a homogeneous
                graph, and the ids inside are homogeneous ids.
              - If `nodes` is a dictionary: The keys should be node types and
                the ids inside are heterogeneous ids.

        Returns
        -------
        FusedSampledSubgraphImpl
            The in subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {
        ...     "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
        >>> in_subgraph = graph.in_subgraph(nodes)
        >>> print(in_subgraph.node_pairs)
        defaultdict(<class 'list'>, {
            'N0:R0:N0': (tensor([]), tensor([])),
            'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])),
            'N1:R2:N0': (tensor([0, 1]), tensor([1, 1])),
            'N1:R3:N1': (tensor([0, 1, 2]), tensor([1, 2, 2]))})
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)
        # Ensure nodes is 1-D tensor.
        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
        # Ensure that there are no duplicate nodes.
        assert len(torch.unique(nodes)) == len(
            nodes
        ), "Nodes cannot have duplicate values."

        _in_subgraph = self._c_csc_graph.in_subgraph(nodes)
        if not output_cscformat:
            return self._convert_to_fused_sampled_subgraph(_in_subgraph)
        else:
            return self._convert_to_sampled_subgraph(_in_subgraph)

    def _convert_to_fused_sampled_subgraph(
        self,
        C_sampled_subgraph: torch.ScriptObject,
    ):
        """An internal function used to convert a fused homogeneous sampled
        subgraph to general struct 'FusedSampledSubgraphImpl'."""
        column_num = (
            C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
        )
        column = C_sampled_subgraph.original_column_node_ids.repeat_interleave(
            column_num
        )
        row = C_sampled_subgraph.indices
        type_per_edge = C_sampled_subgraph.type_per_edge
        original_edge_ids = C_sampled_subgraph.original_edge_ids
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        if has_original_eids:
            original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
                original_edge_ids
            ]
        if type_per_edge is None:
            # The sampled graph is already a homogeneous graph.
            node_pairs = (row, column)
        else:
            # The sampled graph is a fused homogenized graph, which needs to
            # be converted to a heterogeneous graph.
            node_pairs = defaultdict(list)
            original_hetero_edge_ids = {}
            for etype, etype_id in self.edge_type_to_id.items():
                src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
                src_ntype_id = self.node_type_to_id[src_ntype]
                dst_ntype_id = self.node_type_to_id[dst_ntype]
                mask = type_per_edge == etype_id
                hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
                hetero_column = (
                    column[mask] - self.node_type_offset[dst_ntype_id]
                )
                node_pairs[etype] = (hetero_row, hetero_column)
                if has_original_eids:
                    original_hetero_edge_ids[etype] = original_edge_ids[mask]
            if has_original_eids:
                original_edge_ids = original_hetero_edge_ids
        return FusedSampledSubgraphImpl(
            node_pairs=node_pairs, original_edge_ids=original_edge_ids
        )

    def _convert_to_homogeneous_nodes(self, nodes):
        homogeneous_nodes = []
        for ntype, ids in nodes.items():
            ntype_id = self.node_type_to_id[ntype]
            homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
        return torch.cat(homogeneous_nodes)
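    # Worked sketch of the id conversion above (assumes node_type_offset =
    # [0, 2, 5] with "N0" -> 0 and "N1" -> 1, as in the class examples):
    #
    #   {"N0": tensor([1]), "N1": tensor([1, 2])}
    #   -> "N0" ids keep offset 0:           tensor([1])
    #   -> "N1" ids are shifted by offset 2: tensor([3, 4])
    #   -> concatenated homogeneous ids:     tensor([1, 3, 4])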

    def _convert_to_sampled_subgraph(
        self,
        C_sampled_subgraph: torch.ScriptObject,
    ) -> SampledSubgraphImpl:
        """An internal function used to convert a fused homogeneous sampled
        subgraph to general struct 'SampledSubgraphImpl'."""
        indptr = C_sampled_subgraph.indptr
        indices = C_sampled_subgraph.indices
        type_per_edge = C_sampled_subgraph.type_per_edge
        column = C_sampled_subgraph.original_column_node_ids
        original_edge_ids = C_sampled_subgraph.original_edge_ids

        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        if has_original_eids:
            original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
                original_edge_ids
            ]
        if type_per_edge is None:
            # The sampled graph is already a homogeneous graph.
            node_pairs = CSCFormatBase(indptr=indptr, indices=indices)
        else:
            # The sampled graph is a fused homogenized graph, which needs to
            # be converted to a heterogeneous graph.
            # Pre-calculate the number of edges for each etype.
            num = {}
            for etype in type_per_edge:
                num[etype.item()] = num.get(etype.item(), 0) + 1
            # Preallocate
            subgraph_indice_position = {}
            subgraph_indice = {}
            subgraph_indptr = {}
            node_edge_type = defaultdict(list)
            original_hetero_edge_ids = {}
            for etype, etype_id in self.edge_type_to_id.items():
                subgraph_indice[etype] = torch.empty(
                    (num.get(etype_id, 0),), dtype=indices.dtype
                )
                if has_original_eids:
                    original_hetero_edge_ids[etype] = torch.empty(
                        (num.get(etype_id, 0),), dtype=original_edge_ids.dtype
                    )
                subgraph_indptr[etype] = [0]
                subgraph_indice_position[etype] = 0
                # Group edge types by the node type of their destination, so
                # that each seed node only iterates over its relevant etypes.
                _, _, dst_ntype = etype_str_to_tuple(etype)
                dst_ntype_id = self.node_type_to_id[dst_ntype]
                node_edge_type[dst_ntype_id].append((etype, etype_id))
            # construct subgraphs
            for i, seed in enumerate(column):
                l = indptr[i].item()
                r = indptr[i + 1].item()
                node_type = (
                    torch.searchsorted(
                        self.node_type_offset, seed, right=True
                    ).item()
                    - 1
                )
                for etype, etype_id in node_edge_type[node_type]:
                    src_ntype, _, _ = etype_str_to_tuple(etype)
                    src_ntype_id = self.node_type_to_id[src_ntype]
                    num_edges = torch.searchsorted(
                        type_per_edge[l:r], etype_id, right=True
                    ).item()
                    end = num_edges + l
                    subgraph_indptr[etype].append(
                        subgraph_indptr[etype][-1] + num_edges
                    )
                    offset = subgraph_indice_position[etype]
                    subgraph_indice_position[etype] += num_edges
                    subgraph_indice[etype][offset : offset + num_edges] = (
                        indices[l:end] - self.node_type_offset[src_ntype_id]
                    )
                    if has_original_eids:
                        original_hetero_edge_ids[etype][
                            offset : offset + num_edges
                        ] = original_edge_ids[l:end]
                    l = end
            if has_original_eids:
                original_edge_ids = original_hetero_edge_ids
            node_pairs = {
                etype: CSCFormatBase(
                    indptr=torch.tensor(subgraph_indptr[etype]),
                    indices=subgraph_indice[etype],
                )
                for etype in self.edge_type_to_id.keys()
            }
        return SampledSubgraphImpl(
            node_pairs=node_pairs,
            original_edge_ids=original_edge_ids,
        )

    def sample_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
        # TODO: clean up once the migration is done.
        output_cscformat=False,
    ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: It means the graph is a homogeneous
                graph, and the ids inside are homogeneous ids.
              - If `nodes` is a dictionary: The keys should be node types and
                the ids inside are heterogeneous ids.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute.
            This attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        FusedSampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_neighbors(nodes, fanouts)
        >>> print(subgraph.node_pairs)
        defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([0]),
          tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)

        C_sampled_subgraph = self._sample_neighbors(
            nodes, fanouts, replace, probs_name
        )
        if not output_cscformat:
            return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
        else:
            return self._convert_to_sampled_subgraph(C_sampled_subgraph)

    def _check_sampler_arguments(self, nodes, fanouts, probs_name):
        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
        assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
        expected_fanout_len = 1
        if self.edge_type_to_id:
            expected_fanout_len = len(self.edge_type_to_id)
        assert len(fanouts) in [
            expected_fanout_len,
            1,
        ], "Fanouts should have the same number of elements as etypes or \
            should have a length of 1."
        if fanouts.size(0) > 1:
            assert (
                self.type_per_edge is not None
            ), "To perform sampling for each edge type (when the length of \
                `fanouts` > 1), the graph must include edge type information."
        assert torch.all(
            (fanouts >= 0) | (fanouts == -1)
        ), "Fanouts should consist of values that are either -1 or \
            greater than or equal to 0."
        if probs_name:
            assert (
                probs_name in self.edge_attributes
            ), f"Unknown edge attribute '{probs_name}'."
            probs_or_mask = self.edge_attributes[probs_name]
            assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
            assert (
                probs_or_mask.size(0) == self.total_num_edges
            ), "Probs should have the same number of elements as the number \
                of edges."
            assert probs_or_mask.dtype in [
                torch.bool,
                torch.float16,
                torch.bfloat16,
                torch.float32,
                torch.float64,
            ], "Probs should have a floating-point or boolean data type."

    def _sample_neighbors(
        self,
        nodes: torch.Tensor,
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
    ) -> torch.ScriptObject:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph.

        Parameters
        ----------
        nodes: torch.Tensor
            IDs of the given seed nodes.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        torch.classes.graphbolt.SampledSubgraph
            The sampled C subgraph.
        """
        # Ensure nodes is 1-D tensor.
        self._check_sampler_arguments(nodes, fanouts, probs_name)
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        return self._c_csc_graph.sample_neighbors(
            nodes,
            fanouts.tolist(),
            replace,
            False,
            has_original_eids,
            probs_name,
        )

    def sample_layer_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
        # TODO: clean up once the migration is done.
        output_cscformat=False,
    ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
        `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
        <https://arxiv.org/abs/2210.13339>`__

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: It means the graph is a homogeneous
                graph, and the ids inside are homogeneous ids.
              - If `nodes` is a dictionary: The keys should be node types and
                the ids inside are heterogeneous ids.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        FusedSampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
        ...     node_type_offset=node_type_offset,
        ...     type_per_edge=type_per_edge,
        ...     node_type_to_id=ntypes,
        ...     edge_type_to_id=etypes)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
        >>> print(subgraph.node_pairs)
        defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([1]),
          tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)

        self._check_sampler_arguments(nodes, fanouts, probs_name)
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
            nodes,
            fanouts.tolist(),
            replace,
            True,
            has_original_eids,
            probs_name,
        )

        if not output_cscformat:
            return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
        else:
            return self._convert_to_sampled_subgraph(C_sampled_subgraph)

    def sample_negative_edges_uniform(
        self, edge_type, node_pairs, negative_ratio
    ):
        """
        Sample negative edges by randomly choosing negative source-destination
        pairs according to a uniform distribution. For each edge ``(u, v)``,
        it is supposed to generate `negative_ratio` pairs of negative edges
        ``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in
        the graph.

        Parameters
        ----------
        edge_type: str
            The type of edges in the provided node_pairs. Any negative edges
            sampled will also have the same type. If set to None, it will be
            considered as a homogeneous graph.
        node_pairs : Tuple[Tensor, Tensor]
            A tuple of two 1D tensors that represent the source and destination
            of positive edges, with 'positive' indicating that these edges are
            present in the graph. It's important to note that within the
            context of a heterogeneous graph, the ids in these tensors signify
            heterogeneous ids.
        negative_ratio: int
            The ratio of the number of negative samples to positive samples.

        Returns
        -------
        Tuple[Tensor, Tensor]
            A tuple consisting of two 1D tensors that represent the source and
            destination of negative edges. In the context of a heterogeneous
            graph, both the input nodes and the selected nodes are represented
            by heterogeneous IDs, and the formed edges are of the input type
            `edge_type`. Note that the returned pairs are not checked against
            the graph, so a "negative" pair may in fact be an existing edge.
        """
        if edge_type is not None:
            assert (
                self.node_type_offset is not None
            ), "The 'node_type_offset' array is necessary for performing \
                negative sampling by edge type."
            _, _, dst_node_type = etype_str_to_tuple(edge_type)
            dst_node_type_id = self.node_type_to_id[dst_node_type]
            max_node_id = (
                self.node_type_offset[dst_node_type_id + 1]
                - self.node_type_offset[dst_node_type_id]
            )
        else:
            max_node_id = self.total_num_nodes
        return self._c_csc_graph.sample_negative_edges_uniform(
            node_pairs,
            negative_ratio,
            max_node_id,
        )
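    # A minimal usage sketch (assumes the heterogeneous example graph from the
    # sampling docstrings; the positive pairs below are made up for
    # illustration):
    #
    #   pos_src = torch.LongTensor([0, 1])
    #   pos_dst = torch.LongTensor([0, 0])
    #   neg_src, neg_dst = graph.sample_negative_edges_uniform(
    #       "n1:e1:n2", (pos_src, pos_dst), negative_ratio=2
    #   )
    #   # Per the docstring above, each returned pair keeps a positive source
    #   # and pairs it with a uniformly drawn destination of the dst type.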

    def copy_to_shared_memory(self, shared_memory_name: str):
        """Copy the graph to shared memory.

        Parameters
        ----------
        shared_memory_name : str
            Name of the shared memory.

        Returns
        -------
        FusedCSCSamplingGraph
            The copied FusedCSCSamplingGraph object on shared memory.
        """
        return FusedCSCSamplingGraph(
            self._c_csc_graph.copy_to_shared_memory(shared_memory_name),
        )
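    # A minimal round-trip sketch (the shared memory name "shm_graph" is an
    # arbitrary example):
    #
    #   shm_graph = graph.copy_to_shared_memory("shm_graph")
    #   # In another process, attach to the same graph by name:
    #   #   graph2 = gb.load_from_shared_memory("shm_graph")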

    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
        """Copy `FusedCSCSamplingGraph` to the specified device."""

        def _to(x, device):
            return x.to(device) if hasattr(x, "to") else x

        self.csc_indptr = recursive_apply(
            self.csc_indptr, lambda x: _to(x, device)
        )
        self.indices = recursive_apply(self.indices, lambda x: _to(x, device))
        self.node_type_offset = recursive_apply(
            self.node_type_offset, lambda x: _to(x, device)
        )
        self.type_per_edge = recursive_apply(
            self.type_per_edge, lambda x: _to(x, device)
        )
        self.edge_attributes = recursive_apply(
            self.edge_attributes, lambda x: _to(x, device)
        )

        return self
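    # A minimal usage sketch (assumes a CUDA device is available):
    #
    #   graph = graph.to(torch.device("cuda"))
    #   # All tensor members (csc_indptr, indices, node_type_offset,
    #   # type_per_edge, edge_attributes) are moved through the setters above.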


def fused_csc_sampling_graph(
    csc_indptr: torch.Tensor,
    indices: torch.Tensor,
    node_type_offset: Optional[torch.tensor] = None,
    type_per_edge: Optional[torch.tensor] = None,
    node_type_to_id: Optional[Dict[str, int]] = None,
    edge_type_to_id: Optional[Dict[str, int]] = None,
    edge_attributes: Optional[Dict[str, torch.tensor]] = None,
) -> FusedCSCSamplingGraph:
    """Create a FusedCSCSamplingGraph object from a CSC representation.

    Parameters
    ----------
    csc_indptr : torch.Tensor
        Pointer to the start of each row in the `indices`. An integer tensor
        with shape `(total_num_nodes+1,)`.
    indices : torch.Tensor
        Column indices of the non-zero elements in the CSC graph. An integer
        tensor with shape `(total_num_edges,)`.
    node_type_offset : Optional[torch.tensor], optional
        Offset of node types in the graph, by default None.
    type_per_edge : Optional[torch.tensor], optional
        Type ids of each edge in the graph, by default None.
    node_type_to_id : Optional[Dict[str, int]], optional
        Map node types to ids, by default None.
    edge_type_to_id : Optional[Dict[str, int]], optional
        Map edge types to ids, by default None.
    edge_attributes: Optional[Dict[str, torch.tensor]], optional
        Edge attributes of the graph, by default None.

    Returns
    -------
    FusedCSCSamplingGraph
        The created FusedCSCSamplingGraph object.

    Examples
    --------
    >>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
    >>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}
    >>> csc_indptr = torch.tensor([0, 2, 5, 7])
    >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
    >>> node_type_offset = torch.tensor([0, 1, 2, 3])
    >>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
    >>> graph = graphbolt.fused_csc_sampling_graph(csc_indptr, indices,
    ...             node_type_offset=node_type_offset,
    ...             type_per_edge=type_per_edge,
    ...             node_type_to_id=ntypes, edge_type_to_id=etypes,
    ...             edge_attributes=None,)
    >>> print(graph)
    FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
                     indices=tensor([1, 3, 0, 1, 2, 0, 3]),
                     total_num_nodes=3, total_num_edges=7)
    """
    if node_type_to_id is not None and edge_type_to_id is not None:
        node_types = list(node_type_to_id.keys())
        edge_types = list(edge_type_to_id.keys())
        node_type_ids = list(node_type_to_id.values())
        edge_type_ids = list(edge_type_to_id.values())

        # Validate node_type_to_id.
        assert all(
            isinstance(x, str) for x in node_types
        ), "Node type name should be string."
        assert all(
            isinstance(x, int) for x in node_type_ids
        ), "Node type id should be int."
        assert len(node_type_ids) == len(
            set(node_type_ids)
        ), "Multiple node types shoud not be mapped to a same id."
        # Validate edge_type_to_id.
        for edge_type in edge_types:
            src, edge, dst = etype_str_to_tuple(edge_type)
            assert isinstance(edge, str), "Edge type name should be string."
            assert (
                src in node_types
            ), f"Unrecognized node type {src} in edge type {edge_type}"
            assert (
                dst in node_types
            ), f"Unrecognized node type {dst} in edge type {edge_type}"
        assert all(
            isinstance(x, int) for x in edge_type_ids
        ), "Edge type id should be int."
        assert len(edge_type_ids) == len(
            set(edge_type_ids)
        ), "Multiple edge types shoud not be mapped to a same id."

        if node_type_offset is not None:
            assert len(node_type_to_id) + 1 == node_type_offset.size(
                0
            ), "node_type_offset length should be |ntypes| + 1."
    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.fused_csc_sampling_graph(
            csc_indptr,
            indices,
            node_type_offset,
            type_per_edge,
            node_type_to_id,
            edge_type_to_id,
            edge_attributes,
        ),
    )


def load_from_shared_memory(
    shared_memory_name: str,
) -> FusedCSCSamplingGraph:
    """Load a FusedCSCSamplingGraph object from shared memory.

    Parameters
    ----------
    shared_memory_name : str
        Name of the shared memory.

    Returns
    -------
    FusedCSCSamplingGraph
        The loaded FusedCSCSamplingGraph object on shared memory.
    """
    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.load_from_shared_memory(shared_memory_name),
    )


def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
    """Internal function for converting a csc sampling graph to string
    representation.
    """
    csc_indptr_str = str(graph.csc_indptr)
    indices_str = str(graph.indices)
    meta_str = f"num_nodes={graph.total_num_nodes}, num_edges={graph.num_edges}"
    if graph.node_type_offset is not None:
        meta_str += f", node_type_offset={graph.node_type_offset}"
    if graph.type_per_edge is not None:
        meta_str += f", type_per_edge={graph.type_per_edge}"
    if graph.node_type_to_id is not None:
        meta_str += f", node_type_to_id={graph.node_type_to_id}"
    if graph.edge_type_to_id is not None:
        meta_str += f", edge_type_to_id={graph.edge_type_to_id}"
    if graph.edge_attributes is not None:
        meta_str += f", edge_attributes={graph.edge_attributes}"

    prefix = f"{type(graph).__name__}("

    def _add_indent(_str, indent):
        lines = _str.split("\n")
        lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
        return "\n".join(lines)

    final_str = (
        "csc_indptr="
        + _add_indent(csc_indptr_str, len("csc_indptr="))
        + ",\n"
        + "indices="
        + _add_indent(indices_str, len("indices="))
        + ",\n"
        + meta_str
        + ")"
    )

    final_str = prefix + _add_indent(final_str, len(prefix))
    return final_str


def from_dglgraph(
    g: DGLGraph,
    is_homogeneous: bool = False,
    include_original_edge_id: bool = False,
) -> FusedCSCSamplingGraph:
    """Convert a DGLGraph to FusedCSCSamplingGraph."""

    homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)

    if is_homogeneous:
        node_type_to_id = None
        edge_type_to_id = None
    else:
        # Initialize metadata.
        node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
        edge_type_to_id = {
            etype_tuple_to_str(etype): g.get_etype_id(etype)
            for etype in g.canonical_etypes
        }

    # Obtain CSC matrix.
    indptr, indices, edge_ids = homo_g.adj_tensors("csc")
    ntype_count.insert(0, 0)
    node_type_offset = (
        None
        if is_homogeneous
        else torch.cumsum(torch.LongTensor(ntype_count), 0)
    )

    # Assign edge type according to the order of CSC matrix.
    type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]

    edge_attributes = {}
    if include_original_edge_id:
        # Assign edge attributes according to the original eids mapping.
        edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]

    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.fused_csc_sampling_graph(
            indptr,
            indices,
            node_type_offset,
            type_per_edge,
            node_type_to_id,
            edge_type_to_id,
            edge_attributes,
        ),
    )
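# A minimal conversion sketch (assumes ``dgl`` is importable; the graph data
# below is made up for illustration):
#
#   import dgl
#   g = dgl.heterograph({
#       ("n1", "e1", "n2"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
#   })
#   fused_graph = from_dglgraph(g, include_original_edge_id=True)
#   # The original DGL edge ids are kept under the ORIGINAL_EDGE_ID attribute.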