"vscode:/vscode.git/clone" did not exist on "37d992ecf69680e2367eac1a9dcba3de528710d2"
fused_csc_sampling_graph.py 43.3 KB
Newer Older
1
2
"""CSC format sampling graph."""
# pylint: disable= invalid-name
from collections import defaultdict
from typing import Dict, Optional, Union

import torch

from dgl.utils import recursive_apply

from ...base import EID, ETYPE
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import (
    CSCFormatBase,
    FusedSampledSubgraphImpl,
    SampledSubgraphImpl,
)


__all__ = [
    "GraphMetadata",
    "FusedCSCSamplingGraph",
    "from_fused_csc",
    "load_from_shared_memory",
    "from_dglgraph",
]


class GraphMetadata:
    r"""Class for metadata of csc sampling graph."""

    def __init__(
        self,
        node_type_to_id: Dict[str, int],
        edge_type_to_id: Dict[str, int],
    ):
        """Initialize the GraphMetadata object.

        Parameters
        ----------
        node_type_to_id : Dict[str, int]
            Dictionary from node types to node type IDs.
        edge_type_to_id : Dict[str, int]
            Dictionary from edge types to edge type IDs.

        Raises
        ------
        AssertionError
            If any of the assertions fail.
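
        Examples
        --------
        A minimal sketch (the type names below are illustrative only):

        >>> ntypes = {"user": 0, "item": 1}
        >>> etypes = {"user:click:item": 0}
        >>> metadata = GraphMetadata(ntypes, etypes)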
        """

        node_types = list(node_type_to_id.keys())
        edge_types = list(edge_type_to_id.keys())
        node_type_ids = list(node_type_to_id.values())
        edge_type_ids = list(edge_type_to_id.values())

        # Validate node_type_to_id.
        assert all(
            isinstance(x, str) for x in node_types
        ), "Node type name should be string."
        assert all(
            isinstance(x, int) for x in node_type_ids
        ), "Node type id should be int."
        assert len(node_type_ids) == len(
            set(node_type_ids)
        ), "Multiple node types should not be mapped to the same id."
        # Validate edge_type_to_id.
        for edge_type in edge_types:
            src, edge, dst = etype_str_to_tuple(edge_type)
            assert isinstance(edge, str), "Edge type name should be string."
            assert (
                src in node_types
            ), f"Unrecognized node type {src} in edge type {edge_type}"
            assert (
                dst in node_types
            ), f"Unrecognized node type {dst} in edge type {edge_type}"
        assert all(
            isinstance(x, int) for x in edge_type_ids
        ), "Edge type id should be int."
        assert len(edge_type_ids) == len(
            set(edge_type_ids)
        ), "Multiple edge types should not be mapped to the same id."

        self.node_type_to_id = node_type_to_id
        self.edge_type_to_id = edge_type_to_id


class FusedCSCSamplingGraph(SamplingGraph):
    r"""A sampling graph in CSC format."""

    def __repr__(self):
        return _csc_sampling_graph_str(self)

    def __init__(
        self,
        c_csc_graph: torch.ScriptObject,
    ):
        super().__init__()
        self._c_csc_graph = c_csc_graph

    @property
    def total_num_nodes(self) -> int:
        """Returns the number of nodes in the graph.

        Returns
        -------
        int
            The total number of nodes in the graph.
        """
        return self._c_csc_graph.num_nodes()

    @property
    def total_num_edges(self) -> int:
        """Returns the number of edges in the graph.

        Returns
        -------
        int
            The number of edges in the graph.
        """
        return self._c_csc_graph.num_edges()

    @property
    def num_nodes(self) -> Union[int, Dict[str, int]]:
        """The number of nodes in the graph.
        - If the graph is homogeneous, returns an integer.
        - If the graph is heterogeneous, returns a dictionary.

        Returns
        -------
        Union[int, Dict[str, int]]
            The number of nodes. An integer indicates the total number of
            nodes in a homogeneous graph; a dict maps each node type to its
            number of nodes in a heterogeneous graph.

        Examples
        --------
        >>> import dgl.graphbolt as gb, torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
        ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> metadata = gb.GraphMetadata(ntypes, etypes)
        >>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
        ...     type_per_edge, None, metadata)
        >>> print(graph.num_nodes)
        {'N0': 2, 'N1': 3}
        """

        offset = self.node_type_offset

        # Homogeneous.
        if offset is None or self.metadata is None:
            return self._c_csc_graph.num_nodes()

        # Heterogeneous.
        else:
            num_nodes_per_type = {
                _type: (offset[_idx + 1] - offset[_idx]).item()
                for _type, _idx in self.metadata.node_type_to_id.items()
            }

            return num_nodes_per_type

    @property
    def num_edges(self) -> Union[int, Dict[str, int]]:
        """The number of edges in the graph.
        - If the graph is homogeneous, returns an integer.
        - If the graph is heterogeneous, returns a dictionary.

        Returns
        -------
        Union[int, Dict[str, int]]
            The number of edges. An integer indicates the total number of
            edges in a homogeneous graph; a dict maps each edge type to its
            number of edges in a heterogeneous graph.

        Examples
        --------
        >>> import dgl.graphbolt as gb, torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
        ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> metadata = gb.GraphMetadata(ntypes, etypes)
        >>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
        ...     type_per_edge, None, metadata)
        >>> print(graph.num_edges)
        {'N0:R0:N0': 2, 'N0:R1:N1': 4, 'N1:R2:N0': 3, 'N1:R3:N1': 3}
        """

        type_per_edge = self.type_per_edge

        # Homogeneous.
        if type_per_edge is None or self.edge_type_to_id is None:
            return self._c_csc_graph.num_edges()

        # Heterogeneous.
        bincount = torch.bincount(type_per_edge)
        num_edges_per_type = {}
        for etype, etype_id in self.edge_type_to_id.items():
            if etype_id < len(bincount):
                num_edges_per_type[etype] = bincount[etype_id].item()
            else:
                num_edges_per_type[etype] = 0
        return num_edges_per_type

    @property
    def csc_indptr(self) -> torch.tensor:
        """Returns the indices pointer in the CSC graph.

        Returns
        -------
        torch.tensor
            The indices pointer in the CSC graph. An integer tensor with
            shape `(total_num_nodes+1,)`.
        """
        return self._c_csc_graph.csc_indptr()

    @csc_indptr.setter
    def csc_indptr(self, csc_indptr: torch.tensor) -> None:
        """Sets the indices pointer in the CSC graph."""
        self._c_csc_graph.set_csc_indptr(csc_indptr)

    @property
    def indices(self) -> torch.tensor:
        """Returns the indices in the CSC graph.

        Returns
        -------
        torch.tensor
            The indices in the CSC graph. An integer tensor with shape
            `(total_num_edges,)`.

        Notes
        -----
        It is assumed that edges of each node are already sorted by edge type
        ids.
        """
        return self._c_csc_graph.indices()

    @indices.setter
    def indices(self, indices: torch.tensor) -> None:
        """Sets the indices in the CSC graph."""
        self._c_csc_graph.set_indices(indices)

    @property
    def node_type_offset(self) -> Optional[torch.Tensor]:
        """Returns the node type offset tensor if present.

        Returns
        -------
        torch.Tensor or None
            If present, returns a 1D integer tensor of shape
            `(num_node_types + 1,)`. The tensor is in ascending order as nodes
            of the same type have continuous IDs, and larger node IDs are
            paired with larger node type IDs. The first value is 0 and the
            last value is the total number of nodes. Nodes with IDs between
            `node_type_offset[i]` and `node_type_offset[i+1]` are of type
            id `i`.
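
            For example (an illustrative tensor, not drawn from a real graph):
            with `node_type_offset = torch.LongTensor([0, 2, 5])`, nodes 0 and
            1 have type id 0, while nodes 2, 3 and 4 have type id 1.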

        """
        return self._c_csc_graph.node_type_offset()

    @node_type_offset.setter
    def node_type_offset(
        self, node_type_offset: Optional[torch.Tensor]
    ) -> None:
        """Sets the node type offset tensor if present."""
        self._c_csc_graph.set_node_type_offset(node_type_offset)

    @property
    def type_per_edge(self) -> Optional[torch.Tensor]:
        """Returns the edge type tensor if present.

        Returns
        -------
        torch.Tensor or None
            If present, returns a 1D integer tensor of shape (total_num_edges,)
            containing the type of each edge in the graph.
        """
        return self._c_csc_graph.type_per_edge()

    @type_per_edge.setter
    def type_per_edge(self, type_per_edge: Optional[torch.Tensor]) -> None:
        """Sets the edge type tensor if present."""
        self._c_csc_graph.set_type_per_edge(type_per_edge)

    @property
    def node_type_to_id(self) -> Optional[Dict[str, int]]:
        """Returns the node type to id dictionary if present.

        Returns
        -------
        Dict[str, int] or None
            If present, returns a dictionary mapping node type to node type
            id.
        """
        return self._c_csc_graph.node_type_to_id()

    @node_type_to_id.setter
    def node_type_to_id(
        self, node_type_to_id: Optional[Dict[str, int]]
    ) -> None:
        """Sets the node type to id dictionary if present."""
        self._c_csc_graph.set_node_type_to_id(node_type_to_id)

    @property
    def edge_type_to_id(self) -> Optional[Dict[str, int]]:
        """Returns the edge type to id dictionary if present.

        Returns
        -------
        Dict[str, int] or None
            If present, returns a dictionary mapping edge type to edge type
            id.
        """
        return self._c_csc_graph.edge_type_to_id()

    @edge_type_to_id.setter
    def edge_type_to_id(
        self, edge_type_to_id: Optional[Dict[str, int]]
    ) -> None:
        """Sets the edge type to id dictionary if present."""
        self._c_csc_graph.set_edge_type_to_id(edge_type_to_id)

    @property
    def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
        """Returns the edge attributes dictionary.

        Returns
        -------
        Dict[str, torch.Tensor] or None
            If present, returns a dictionary of edge attributes. Each key
            represents the attribute's name, while the corresponding value
            holds the attribute's specific value. The length of each value
            should match the total number of edges.
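
        Examples
        --------
        An illustrative sketch (assuming `graph`, `nodes` and `fanouts` are
        defined as in the `sample_neighbors` examples; the attribute name
        "prob" is made up):

        >>> import torch
        >>> graph.edge_attributes = {
        ...     "prob": torch.rand(graph.total_num_edges)}
        >>> subgraph = graph.sample_neighbors(
        ...     nodes, fanouts, probs_name="prob")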
        """
        return self._c_csc_graph.edge_attributes()

    @edge_attributes.setter
    def edge_attributes(
        self, edge_attributes: Optional[Dict[str, torch.Tensor]]
    ) -> None:
        """Sets the edge attributes dictionary."""
        self._c_csc_graph.set_edge_attributes(edge_attributes)

    @property
    def metadata(self) -> Optional[GraphMetadata]:
        """Returns the metadata of the graph.

        [TODO][Rui] This API needs to be updated.

        Returns
        -------
        GraphMetadata or None
            If present, returns the metadata of the graph.
        """
        if self.node_type_to_id is None or self.edge_type_to_id is None:
            return None
        return GraphMetadata(self.node_type_to_id, self.edge_type_to_id)

    def in_subgraph(
        self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
    ) -> FusedSampledSubgraphImpl:
        """Return the subgraph induced on the inbound edges of the given nodes.

        An in subgraph is equivalent to creating a new graph using the incoming
        edges of the given nodes. The subgraph is compacted according to the
        order of the passed-in `nodes`.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: the graph is treated as a homogeneous
                graph, and the IDs inside are homogeneous IDs.
              - If `nodes` is a dictionary: the keys should be node types and
                the IDs inside are heterogeneous IDs.

        Returns
        -------
        FusedSampledSubgraphImpl
            The in subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {
        ...     "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> metadata = gb.GraphMetadata(ntypes, etypes)
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
        ...     type_per_edge, None, metadata)
        >>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
        >>> in_subgraph = graph.in_subgraph(nodes)
        >>> print(in_subgraph.node_pairs)
        defaultdict(<class 'list'>, {
            'N0:R0:N0': (tensor([]), tensor([])),
            'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])),
            'N1:R2:N0': (tensor([0, 1]), tensor([1, 1])),
            'N1:R3:N1': (tensor([0, 1, 2]), tensor([1, 2, 2]))})
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)
        # Ensure nodes is 1-D tensor.
        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
        # Ensure that there are no duplicate nodes.
        assert len(torch.unique(nodes)) == len(
            nodes
        ), "Nodes cannot have duplicate values."

        _in_subgraph = self._c_csc_graph.in_subgraph(nodes)
        return self._convert_to_fused_sampled_subgraph(_in_subgraph)

    def _convert_to_fused_sampled_subgraph(
        self,
        C_sampled_subgraph: torch.ScriptObject,
    ):
        """An internal function used to convert a fused homogeneous sampled
        subgraph to general struct 'FusedSampledSubgraphImpl'."""
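        # Expand the compressed column pointer into one column ID per sampled
        # edge so that (row, column) forms a COO-style pair representation.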
        column_num = (
            C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
        )
        column = C_sampled_subgraph.original_column_node_ids.repeat_interleave(
            column_num
        )
        row = C_sampled_subgraph.indices
        type_per_edge = C_sampled_subgraph.type_per_edge
        original_edge_ids = C_sampled_subgraph.original_edge_ids
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        if has_original_eids:
            original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
                original_edge_ids
            ]
        if type_per_edge is None:
            # The sampled graph is already a homogeneous graph.
            node_pairs = (row, column)
        else:
            # The sampled graph is a fused homogenized graph, which needs to
            # be converted to a heterogeneous graph.
            node_pairs = defaultdict(list)
            original_hetero_edge_ids = {}
            for etype, etype_id in self.metadata.edge_type_to_id.items():
                src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
                src_ntype_id = self.metadata.node_type_to_id[src_ntype]
                dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
                mask = type_per_edge == etype_id
                hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
                hetero_column = (
                    column[mask] - self.node_type_offset[dst_ntype_id]
                )
                node_pairs[etype] = (hetero_row, hetero_column)
                if has_original_eids:
                    original_hetero_edge_ids[etype] = original_edge_ids[mask]
            if has_original_eids:
                original_edge_ids = original_hetero_edge_ids
        return FusedSampledSubgraphImpl(
            node_pairs=node_pairs, original_edge_ids=original_edge_ids
        )

    def _convert_to_homogeneous_nodes(self, nodes):
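        # Map per-type (heterogeneous) node IDs into the homogeneous ID space
        # by adding the offset at which each node type starts.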
        homogeneous_nodes = []
        for ntype, ids in nodes.items():
            ntype_id = self.metadata.node_type_to_id[ntype]
            homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
        return torch.cat(homogeneous_nodes)

    def _convert_to_sampled_subgraph(
        self,
        C_sampled_subgraph: torch.ScriptObject,
    ) -> SampledSubgraphImpl:
        """An internal function used to convert a fused homogeneous sampled
        subgraph to general struct 'SampledSubgraphImpl'."""
        indptr = C_sampled_subgraph.indptr
        indices = C_sampled_subgraph.indices
        type_per_edge = C_sampled_subgraph.type_per_edge
        column = C_sampled_subgraph.original_column_node_ids
        original_edge_ids = C_sampled_subgraph.original_edge_ids

        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        if has_original_eids:
            original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
                original_edge_ids
            ]
        if type_per_edge is None:
            # The sampled graph is already a homogeneous graph.
            node_pairs = CSCFormatBase(indptr=indptr, indices=indices)
        else:
            # The sampled graph is a fused homogenized graph, which needs to
            # be converted to a heterogeneous graph.
            # Pre-calculate the number of edges of each etype.
            num = {}
            for etype in type_per_edge:
                num[etype.item()] = num.get(etype.item(), 0) + 1
            # Preallocate
            subgraph_indice_position = {}
            subgraph_indice = {}
            subgraph_indptr = {}
            node_edge_type = defaultdict(list)
            original_hetero_edge_ids = {}
            for etype, etype_id in self.metadata.edge_type_to_id.items():
                subgraph_indice[etype] = torch.empty(
                    (num.get(etype_id, 0),), dtype=indices.dtype
                )
                if has_original_eids:
                    original_hetero_edge_ids[etype] = torch.empty(
                        (num.get(etype_id, 0),), dtype=original_edge_ids.dtype
                    )
                subgraph_indptr[etype] = [0]
                subgraph_indice_position[etype] = 0
                # Preprocessing saves the type of seed_nodes as the edge type
                # of dst_ntype.
                _, _, dst_ntype = etype_str_to_tuple(etype)
                dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
                node_edge_type[dst_ntype_id].append((etype, etype_id))
            # construct subgraphs
            for i, seed in enumerate(column):
                l = indptr[i].item()
                r = indptr[i + 1].item()
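                # `seed` is a homogeneous node ID; use searchsorted over
                # node_type_offset to recover its node type id.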
                node_type = (
                    torch.searchsorted(
                        self.node_type_offset, seed, right=True
                    ).item()
                    - 1
                )
                for etype, etype_id in node_edge_type[node_type]:
                    src_ntype, _, _ = etype_str_to_tuple(etype)
                    src_ntype_id = self.metadata.node_type_to_id[src_ntype]
                    num_edges = torch.searchsorted(
                        type_per_edge[l:r], etype_id, right=True
                    ).item()
                    end = num_edges + l
                    subgraph_indptr[etype].append(
                        subgraph_indptr[etype][-1] + num_edges
                    )
                    offset = subgraph_indice_position[etype]
                    subgraph_indice_position[etype] += num_edges
                    subgraph_indice[etype][offset : offset + num_edges] = (
                        indices[l:end] - self.node_type_offset[src_ntype_id]
                    )
                    if has_original_eids:
                        original_hetero_edge_ids[etype][
                            offset : offset + num_edges
                        ] = original_edge_ids[l:end]
                    l = end
            if has_original_eids:
                original_edge_ids = original_hetero_edge_ids
            node_pairs = {
                etype: CSCFormatBase(
                    indptr=torch.tensor(subgraph_indptr[etype]),
                    indices=subgraph_indice[etype],
                )
                for etype in self.metadata.edge_type_to_id.keys()
            }
        return SampledSubgraphImpl(
            node_pairs=node_pairs,
            original_edge_ids=original_edge_ids,
        )

    def sample_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
        # TODO: clean up once the migration is done.
        output_cscformat=False,
    ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: the graph is treated as a homogeneous
                graph, and the IDs inside are homogeneous IDs.
              - If `nodes` is a dictionary: the keys should be node types and
                the IDs inside are heterogeneous IDs.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should be equal to the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute.
            This attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        FusedSampledSubgraphImpl or SampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> metadata = gb.GraphMetadata(ntypes, etypes)
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
        ...     node_type_offset=node_type_offset, metadata=metadata)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_neighbors(nodes, fanouts)
        >>> print(subgraph.node_pairs)
        defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([0]),
          tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)

        C_sampled_subgraph = self._sample_neighbors(
            nodes, fanouts, replace, probs_name
        )
        if not output_cscformat:
            return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
        else:
            return self._convert_to_sampled_subgraph(C_sampled_subgraph)

    def _check_sampler_arguments(self, nodes, fanouts, probs_name):
        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
        assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
        expected_fanout_len = 1
        if self.metadata and self.metadata.edge_type_to_id:
            expected_fanout_len = len(self.metadata.edge_type_to_id)
        assert len(fanouts) in [
            expected_fanout_len,
            1,
        ], "Fanouts should have the same number of elements as etypes or \
            should have a length of 1."
        if fanouts.size(0) > 1:
            assert (
                self.type_per_edge is not None
            ), "To perform sampling for each edge type (when the length of \
                `fanouts` > 1), the graph must include edge type information."
        assert torch.all(
            (fanouts >= 0) | (fanouts == -1)
        ), "Fanouts should consist of values that are either -1 or \
            greater than or equal to 0."
        if probs_name:
            assert (
                probs_name in self.edge_attributes
            ), f"Unknown edge attribute '{probs_name}'."
            probs_or_mask = self.edge_attributes[probs_name]
            assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
            assert (
                probs_or_mask.size(0) == self.total_num_edges
            ), "Probs should have the same number of elements as the number \
                of edges."
            assert probs_or_mask.dtype in [
                torch.bool,
                torch.float16,
                torch.bfloat16,
                torch.float32,
                torch.float64,
            ], "Probs should have a floating-point or boolean data type."

    def _sample_neighbors(
        self,
        nodes: torch.Tensor,
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
    ) -> torch.ScriptObject:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph.

        Parameters
        ----------
        nodes: torch.Tensor
            IDs of the given seed nodes.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should be equal to the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        torch.classes.graphbolt.SampledSubgraph
            The sampled C subgraph.
        """
        # Validate the sampler arguments.
        self._check_sampler_arguments(nodes, fanouts, probs_name)
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        return self._c_csc_graph.sample_neighbors(
            nodes,
            fanouts.tolist(),
            replace,
            False,
            has_original_eids,
            probs_name,
        )

    def sample_layer_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
        # TODO: clean up once the migration is done.
        output_cscformat=False,
    ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
        `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
        <https://arxiv.org/abs/2210.13339>`__

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: the graph is treated as a homogeneous
                graph, and the IDs inside are homogeneous IDs.
              - If `nodes` is a dictionary: the keys should be node types and
                the IDs inside are heterogeneous IDs.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should be equal to the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        FusedSampledSubgraphImpl or SampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> metadata = gb.GraphMetadata(ntypes, etypes)
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
        ...     node_type_offset=node_type_offset, metadata=metadata)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
        >>> print(subgraph.node_pairs)
        defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([1]),
          tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
        """
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)

        self._check_sampler_arguments(nodes, fanouts, probs_name)
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
            nodes,
            fanouts.tolist(),
            replace,
            True,
            has_original_eids,
            probs_name,
        )

        if not output_cscformat:
            return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
        else:
            return self._convert_to_sampled_subgraph(C_sampled_subgraph)

    def sample_negative_edges_uniform(
        self, edge_type, node_pairs, negative_ratio
    ):
        """
        Sample negative edges by randomly choosing negative source-destination
        pairs according to a uniform distribution. For each edge ``(u, v)``,
        it generates `negative_ratio` pairs of negative edges
        ``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in
        the graph.

        Parameters
        ----------
        edge_type: str
            The type of edges in the provided node_pairs. Any negative edges
            sampled will also have the same type. If set to None, it will be
            considered as a homogeneous graph.
        node_pairs : Tuple[Tensor, Tensor]
            A tuple of two 1D tensors that represent the source and destination
            of positive edges, with 'positive' indicating that these edges are
            present in the graph. It's important to note that within the
            context of a heterogeneous graph, the ids in these tensors signify
            heterogeneous ids.
        negative_ratio: int
            The ratio of the number of negative samples to positive samples.

        Returns
        -------
        Tuple[Tensor, Tensor]
            A tuple of two 1D tensors representing the source and destination
            of the negative edges. In the context of a heterogeneous graph,
            both the input nodes and the selected nodes are represented by
            heterogeneous IDs, and the formed edges are of the input type
            `edge_type`. Note that the sampled negatives may include false
            negatives, i.e. a sampled pair may or may not actually be present
            in the graph.
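
        Examples
        --------
        An illustrative sketch (assuming `graph` is a heterogeneous
        `FusedCSCSamplingGraph` that contains the edge type "n1:e1:n2", as in
        the `sample_neighbors` examples):

        >>> node_pairs = (torch.LongTensor([0]), torch.LongTensor([1]))
        >>> neg_src, neg_dst = graph.sample_negative_edges_uniform(
        ...     "n1:e1:n2", node_pairs, negative_ratio=2)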
        """
        if edge_type is not None:
            assert (
                self.node_type_offset is not None
            ), "The 'node_type_offset' array is necessary for performing \
                negative sampling by edge type."
            _, _, dst_node_type = etype_str_to_tuple(edge_type)
            dst_node_type_id = self.metadata.node_type_to_id[dst_node_type]
            max_node_id = (
                self.node_type_offset[dst_node_type_id + 1]
                - self.node_type_offset[dst_node_type_id]
            )
        else:
            max_node_id = self.total_num_nodes
        return self._c_csc_graph.sample_negative_edges_uniform(
            node_pairs,
            negative_ratio,
            max_node_id,
        )

    def copy_to_shared_memory(self, shared_memory_name: str):
        """Copy the graph to shared memory.

        Parameters
        ----------
        shared_memory_name : str
            Name of the shared memory.

        Returns
        -------
        FusedCSCSamplingGraph
            The copied FusedCSCSamplingGraph object on shared memory.
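
        Examples
        --------
        A minimal sketch (the shared-memory name below is arbitrary):

        >>> shm_graph = graph.copy_to_shared_memory("fused_graph_shm")
        >>> # A separate process can then reattach to the same graph via
        >>> # gb.load_from_shared_memory("fused_graph_shm").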
        """
        return FusedCSCSamplingGraph(
            self._c_csc_graph.copy_to_shared_memory(shared_memory_name),
        )

    def to(self, device: torch.device):  # pylint: disable=invalid-name
        """Copy `FusedCSCSamplingGraph` to the specified device."""

        def _to(x, device):
            return x.to(device) if hasattr(x, "to") else x

        self.csc_indptr = recursive_apply(
            self.csc_indptr, lambda x: _to(x, device)
        )
        self.indices = recursive_apply(self.indices, lambda x: _to(x, device))
        self.node_type_offset = recursive_apply(
            self.node_type_offset, lambda x: _to(x, device)
        )
        self.type_per_edge = recursive_apply(
            self.type_per_edge, lambda x: _to(x, device)
        )
        self.edge_attributes = recursive_apply(
            self.edge_attributes, lambda x: _to(x, device)
        )

        return self


def from_fused_csc(
    csc_indptr: torch.Tensor,
    indices: torch.Tensor,
    node_type_offset: Optional[torch.tensor] = None,
    type_per_edge: Optional[torch.tensor] = None,
    edge_attributes: Optional[Dict[str, torch.tensor]] = None,
    metadata: Optional[GraphMetadata] = None,
) -> FusedCSCSamplingGraph:
    """Create a FusedCSCSamplingGraph object from a CSC representation.

    Parameters
    ----------
    csc_indptr : torch.Tensor
        Pointer to the start of each row in the `indices`. An integer tensor
        with shape `(total_num_nodes+1,)`.
    indices : torch.Tensor
        Column indices of the non-zero elements in the CSC graph. An integer
        tensor with shape `(total_num_edges,)`.
    node_type_offset : Optional[torch.tensor], optional
        Offset of node types in the graph, by default None.
    type_per_edge : Optional[torch.tensor], optional
        Type ids of each edge in the graph, by default None.
    edge_attributes: Optional[Dict[str, torch.tensor]], optional
        Edge attributes of the graph, by default None.
    metadata: Optional[GraphMetadata], optional
        Metadata of the graph, by default None.

    Returns
    -------
    FusedCSCSamplingGraph
        The created FusedCSCSamplingGraph object.

    Examples
    --------
    >>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
    >>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}
    >>> metadata = graphbolt.GraphMetadata(ntypes, etypes)
    >>> csc_indptr = torch.tensor([0, 2, 5, 7])
    >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
    >>> node_type_offset = torch.tensor([0, 1, 2, 3])
    >>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
    >>> graph = graphbolt.from_fused_csc(csc_indptr, indices,
    ...             node_type_offset=node_type_offset,
    ...             type_per_edge=type_per_edge,
    ...             edge_attributes=None, metadata=metadata)
    >>> print(graph)
    FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
                     indices=tensor([1, 3, 0, 1, 2, 0, 3]),
                     total_num_nodes=3, total_num_edges=7)
    """
    if metadata and metadata.node_type_to_id and node_type_offset is not None:
        assert len(metadata.node_type_to_id) + 1 == node_type_offset.size(
            0
        ), "node_type_offset length should be |ntypes| + 1."
    node_type_to_id = metadata.node_type_to_id if metadata else None
    edge_type_to_id = metadata.edge_type_to_id if metadata else None
    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.from_fused_csc(
            csc_indptr,
            indices,
            node_type_offset,
            type_per_edge,
            node_type_to_id,
            edge_type_to_id,
            edge_attributes,
        ),
    )


def load_from_shared_memory(
    shared_memory_name: str,
) -> FusedCSCSamplingGraph:
    """Load a FusedCSCSamplingGraph object from shared memory.

    Parameters
    ----------
    shared_memory_name : str
        Name of the shared memory.

    Returns
    -------
    FusedCSCSamplingGraph
        The loaded FusedCSCSamplingGraph object on shared memory.
    """
    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.load_from_shared_memory(shared_memory_name),
    )


def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
    """Internal function for converting a csc sampling graph to string
    representation.
    """
    csc_indptr_str = str(graph.csc_indptr)
    indices_str = str(graph.indices)
    meta_str = (
        f"total_num_nodes={graph.total_num_nodes}, total_num_edges="
        f"{graph.total_num_edges}"
    )
    prefix = f"{type(graph).__name__}("

    def _add_indent(_str, indent):
        lines = _str.split("\n")
        lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
        return "\n".join(lines)

    final_str = (
        "csc_indptr="
        + _add_indent(csc_indptr_str, len("csc_indptr="))
        + ",\n"
        + "indices="
        + _add_indent(indices_str, len("indices="))
        + ",\n"
        + meta_str
        + ")"
    )

    final_str = prefix + _add_indent(final_str, len(prefix))
    return final_str


def from_dglgraph(
    g: DGLGraph,
    is_homogeneous: bool = False,
    include_original_edge_id: bool = False,
) -> FusedCSCSamplingGraph:
    """Convert a DGLGraph to FusedCSCSamplingGraph."""

    homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)

    if is_homogeneous:
        metadata = None
    else:
        # Initialize metadata.
        node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
        edge_type_to_id = {
            etype_tuple_to_str(etype): g.get_etype_id(etype)
            for etype in g.canonical_etypes
        }
        metadata = GraphMetadata(node_type_to_id, edge_type_to_id)

    # Obtain CSC matrix.
    indptr, indices, edge_ids = homo_g.adj_tensors("csc")
    ntype_count.insert(0, 0)
    node_type_offset = (
        None
        if is_homogeneous
        else torch.cumsum(torch.LongTensor(ntype_count), 0)
    )

    # Assign edge type according to the order of CSC matrix.
    type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]

    edge_attributes = {}
    if include_original_edge_id:
        # Assign edge attributes according to the original eids mapping.
        edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]

    node_type_to_id = metadata.node_type_to_id if metadata else None
    edge_type_to_id = metadata.edge_type_to_id if metadata else None
    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.from_fused_csc(
            indptr,
            indices,
            node_type_offset,
            type_per_edge,
            node_type_to_id,
            edge_type_to_id,
            edge_attributes,
        ),
    )
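

# Illustrative usage of ``from_dglgraph`` (a minimal sketch; it assumes a
# small homogeneous DGLGraph built with ``dgl.graph``):
#
#     import dgl
#     import torch
#
#     g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
#     fused_graph = from_dglgraph(
#         g, is_homogeneous=True, include_original_edge_id=True
#     )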