fused_csc_sampling_graph.py 34.9 KB
Newer Older
1
2
"""CSC format sampling graph."""
# pylint: disable= invalid-name
3
4
5
import os
import tarfile
import tempfile
6
from collections import defaultdict
7
from typing import Dict, Optional, Union
8
9
10

import torch

11
12
from dgl.utils import recursive_apply

13
from ...base import EID, ETYPE
14
15
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
16
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
17
from ..sampling_graph import SamplingGraph
18
from .sampled_subgraph_impl import FusedSampledSubgraphImpl
19

20

21
22
# Names exported by `from <this module> import *`.
__all__ = [
    "GraphMetadata",
    "FusedCSCSamplingGraph",
    "from_fused_csc",
    "load_from_shared_memory",
    "load_fused_csc_sampling_graph",
    "save_fused_csc_sampling_graph",
    "from_dglgraph",
]


32
33
34
35
36
37
class GraphMetadata:
    r"""Class for metadata of csc sampling graph.

    Holds the two name-to-id dictionaries given at construction, after
    validating them, as the attributes ``node_type_to_id`` and
    ``edge_type_to_id``.
    """

    def __init__(
        self,
        node_type_to_id: Dict[str, int],
        edge_type_to_id: Dict[str, int],
    ):
        """Initialize the GraphMetadata object.

        Parameters
        ----------
        node_type_to_id : Dict[str, int]
            Dictionary from node types to node type IDs.
        edge_type_to_id : Dict[str, int]
            Dictionary from edge types to edge type IDs. Every key is parsed
            with `etype_str_to_tuple` into `(src, edge, dst)`, and both `src`
            and `dst` must be keys of `node_type_to_id`.

        Raises
        ------
        AssertionError
            If any of the assertions fail.
        """
        node_type_ids = list(node_type_to_id.values())
        edge_type_ids = list(edge_type_to_id.values())

        # Validate node_type_to_id.
        assert all(
            isinstance(x, str) for x in node_type_to_id
        ), "Node type name should be string."
        assert all(
            isinstance(x, int) for x in node_type_ids
        ), "Node type id should be int."
        assert len(node_type_ids) == len(
            set(node_type_ids)
        ), "Multiple node types shoud not be mapped to a same id."
        # Validate edge_type_to_id.
        for edge_type in edge_type_to_id:
            src, edge, dst = etype_str_to_tuple(edge_type)
            assert isinstance(edge, str), "Edge type name should be string."
            # Membership is tested against the dict itself (hash lookup)
            # rather than a list copy of its keys.
            assert (
                src in node_type_to_id
            ), f"Unrecognized node type {src} in edge type {edge_type}"
            assert (
                dst in node_type_to_id
            ), f"Unrecognized node type {dst} in edge type {edge_type}"
        assert all(
            isinstance(x, int) for x in edge_type_ids
        ), "Edge type id should be int."
        assert len(edge_type_ids) == len(
            set(edge_type_ids)
        ), "Multiple edge types shoud not be mapped to a same id."

        self.node_type_to_id = node_type_to_id
        self.edge_type_to_id = edge_type_to_id


91
class FusedCSCSamplingGraph(SamplingGraph):
92
    r"""A sampling graph in CSC format."""
93
94
95
96
97
98
99

    def __repr__(self):
        """Return the string built by `_csc_sampling_graph_str` for this graph."""
        text = _csc_sampling_graph_str(self)
        return text

    def __init__(
        self, c_csc_graph: torch.ScriptObject, metadata: Optional[GraphMetadata]
    ):
        """Wrap an underlying graph object together with optional metadata.

        Parameters
        ----------
        c_csc_graph : torch.ScriptObject
            The underlying graph object. The accessors and sampling methods
            of this class delegate to it.
        metadata : GraphMetadata or None
            Node/edge type dictionaries of the graph, or None.
        """
        super().__init__()
        self._c_csc_graph = c_csc_graph
        self._metadata = metadata

    @property
    def total_num_nodes(self) -> int:
        """Returns the number of nodes in the graph.

        Returns
        -------
        int
            The number of nodes reported by the underlying graph object
            (the number of rows in the dense format).
        """
        return self._c_csc_graph.num_nodes()

    @property
    def total_num_edges(self) -> int:
        """Returns the number of edges in the graph.

        Returns
        -------
        int
            The number of edges reported by the underlying graph object.
        """
        return self._c_csc_graph.num_edges()

126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
    @property
    def num_nodes(self) -> Union[int, Dict[str, int]]:
        """The number of nodes in the graph.
        - If the graph is homogenous, returns an integer.
        - If the graph is heterogenous, returns a dictionary.

        The graph is treated as homogenous when either `node_type_offset` or
        `metadata` is None.

        Returns
        -------
        Union[int, Dict[str, int]]
            The number of nodes. Integer indicates the total nodes number of a
            homogenous graph; dict indicates nodes number per node types of a
            heterogenous graph.

        Examples
        --------
        >>> import dgl.graphbolt as gb, torch
        >>> total_num_nodes = 5
        >>> total_num_edges = 12
        >>> ntypes = {"N0": 0, "N1": 1}
        >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
        ... "N1:R2:N0": 2, "N1:R3:N1": 3}
        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor(
        ... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
        >>> metadata = gb.GraphMetadata(ntypes, etypes)
        >>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
        ... type_per_edge, None, metadata)
        >>> print(graph.num_nodes)
        {'N0': 2, 'N1': 3}
        """
        offset = self.node_type_offset

        # Homogenous: no per-type breakdown can be computed.
        if offset is None or self.metadata is None:
            return self._c_csc_graph.num_nodes()

        # Heterogenous: the count of a type is the gap between its two
        # consecutive entries in the offset tensor.
        return {
            _type: (offset[_idx + 1] - offset[_idx]).item()
            for _type, _idx in self.metadata.node_type_to_id.items()
        }

174
175
176
177
178
179
180
181
    @property
    def csc_indptr(self) -> torch.tensor:
        """Returns the indices pointer in the CSC graph.

        Returns
        -------
        torch.tensor
            The indices pointer in the CSC graph. An integer tensor with
            shape `(total_num_nodes+1,)`.
        """
        return self._c_csc_graph.csc_indptr()

186
187
188
189
190
    @csc_indptr.setter
    def csc_indptr(self, csc_indptr: torch.tensor) -> None:
        """Hand a new indices pointer tensor to the underlying graph object."""
        graph = self._c_csc_graph
        graph.set_csc_indptr(csc_indptr)

191
192
193
194
195
196
197
198
    @property
    def indices(self) -> torch.tensor:
        """Returns the indices in the CSC graph.

        Returns
        -------
        torch.tensor
            The indices in the CSC graph. An integer tensor with shape
            `(total_num_edges,)`.

        Notes
        -------
        It is assumed that edges of each node are already sorted by edge type
        ids.
        """
        return self._c_csc_graph.indices()

208
209
210
211
212
    @indices.setter
    def indices(self, indices: torch.tensor) -> None:
        """Hand a new indices tensor to the underlying graph object."""
        graph = self._c_csc_graph
        graph.set_indices(indices)

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
    @property
    def node_type_offset(self) -> Optional[torch.Tensor]:
        """Node type offset tensor of the graph, or None when absent.

        Returns
        -------
        torch.Tensor or None
            When present, a 1D integer tensor of shape
            `(num_node_types + 1,)` in ascending order: nodes of one type
            occupy a continuous ID range, and larger node IDs go with larger
            node type IDs. The first entry is 0 and the last entry is the
            number of nodes. Nodes whose IDs lie between
            `node_type_offset_[i]~node_type_offset_[i+1]` have type id 'i'.
        """
        offset = self._c_csc_graph.node_type_offset()
        return offset

230
231
232
233
234
235
236
    @node_type_offset.setter
    def node_type_offset(
        self, node_type_offset: Optional[torch.Tensor]
    ) -> None:
        """Hand a new node type offset tensor (or None) to the underlying
        graph object."""
        graph = self._c_csc_graph
        graph.set_node_type_offset(node_type_offset)

237
238
239
240
241
242
243
    @property
    def type_per_edge(self) -> Optional[torch.Tensor]:
        """Returns the edge type tensor if present.

        Returns
        -------
        torch.Tensor or None
            If present, returns a 1D integer tensor of shape (total_num_edges,)
            containing the type of each edge in the graph.
        """
        return self._c_csc_graph.type_per_edge()

249
250
251
252
253
    @type_per_edge.setter
    def type_per_edge(self, type_per_edge: Optional[torch.Tensor]) -> None:
        """Hand a new edge type tensor (or None) to the underlying graph
        object."""
        graph = self._c_csc_graph
        graph.set_type_per_edge(type_per_edge)

254
255
256
257
258
259
260
261
262
263
264
265
266
267
    @property
    def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
        """Edge attributes dictionary of the graph, or None when absent.

        Returns
        -------
        Dict[str, torch.Tensor] or None
            When present, a dictionary of edge attributes. Each key is an
            attribute name and each value is that attribute's tensor. The
            length of each value should match the total number of edges.
        """
        attributes = self._c_csc_graph.edge_attributes()
        return attributes

268
269
270
271
272
273
274
    @edge_attributes.setter
    def edge_attributes(
        self, edge_attributes: Optional[Dict[str, torch.Tensor]]
    ) -> None:
        """Hand a new edge attributes dictionary (or None) to the underlying
        graph object."""
        graph = self._c_csc_graph
        graph.set_edge_attributes(edge_attributes)

275
276
277
278
279
280
281
282
283
284
285
    @property
    def metadata(self) -> Optional[GraphMetadata]:
        """Metadata object given at construction time.

        Returns
        -------
        GraphMetadata or None
            The stored metadata of the graph, or None when the graph was
            created without one.
        """
        stored = self._metadata
        return stored

286
287
288
289
290
291
292
293
294
295
296
297
298
    def in_subgraph(self, nodes: torch.Tensor) -> torch.ScriptObject:
        """Return the subgraph induced on the inbound edges of the given nodes.

        An in subgraph is equivalent to creating a new graph using the incoming
        edges of the given nodes.

        Parameters
        ----------
        nodes : torch.Tensor
            The nodes to form the subgraph which are type agnostic. Must be
            a 1-D tensor without duplicate values.

        Returns
        -------
        torch.classes.graphbolt.SampledSubgraph
            The in subgraph.

        Raises
        ------
        AssertionError
            If `nodes` is not 1-D or contains duplicate values.
        """
        # Ensure nodes is 1-D tensor.
        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
        # Ensure that there are no duplicate nodes.
        assert len(torch.unique(nodes)) == len(
            nodes
        ), "Nodes cannot have duplicate values."
        # TODO: change the result to 'FusedSampledSubgraphImpl'.
        return self._c_csc_graph.in_subgraph(nodes)

311
    def _convert_to_sampled_subgraph(
        self,
        C_sampled_subgraph: torch.ScriptObject,
    ):
        """An internal function used to convert a fused homogeneous sampled
        subgraph to general struct 'FusedSampledSubgraphImpl'.

        Parameters
        ----------
        C_sampled_subgraph : torch.ScriptObject
            Sampled subgraph object. This function reads its `indptr`,
            `indices`, `original_column_node_ids`, `type_per_edge` and
            `original_edge_ids` attributes.

        Returns
        -------
        FusedSampledSubgraphImpl
            When `type_per_edge` is None, `node_pairs` is a single
            `(row, column)` tuple. Otherwise `node_pairs` maps every edge
            type in the metadata to a `(row, column)` tuple whose ids have
            the node type offsets subtracted.
        """
        # Number of sampled edges per column, then one column id per edge.
        column_num = (
            C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
        )
        column = C_sampled_subgraph.original_column_node_ids.repeat_interleave(
            column_num
        )
        row = C_sampled_subgraph.indices
        type_per_edge = C_sampled_subgraph.type_per_edge
        original_edge_ids = C_sampled_subgraph.original_edge_ids
        # Read the property once instead of on every use.
        edge_attributes = self.edge_attributes
        has_original_eids = (
            edge_attributes is not None
            and ORIGINAL_EDGE_ID in edge_attributes
        )
        if has_original_eids:
            # Map the sampled edge ids through the stored original-id
            # attribute.
            original_edge_ids = edge_attributes[ORIGINAL_EDGE_ID][
                original_edge_ids
            ]
        if type_per_edge is None:
            # The sampled graph is already a homogeneous graph.
            node_pairs = (row, column)
        else:
            # The sampled graph is a fused homogenized graph, which need to be
            # converted to heterogeneous graphs.
            node_pairs = defaultdict(list)
            original_hetero_edge_ids = {}
            node_type_to_id = self.metadata.node_type_to_id
            # Loop-invariant: fetch the offset tensor once.
            node_type_offset = self.node_type_offset
            for etype, etype_id in self.metadata.edge_type_to_id.items():
                src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
                src_ntype_id = node_type_to_id[src_ntype]
                dst_ntype_id = node_type_to_id[dst_ntype]
                mask = type_per_edge == etype_id
                # Subtract the type's start offset to get per-type ids.
                hetero_row = row[mask] - node_type_offset[src_ntype_id]
                hetero_column = column[mask] - node_type_offset[dst_ntype_id]
                node_pairs[etype] = (hetero_row, hetero_column)
                if has_original_eids:
                    original_hetero_edge_ids[etype] = original_edge_ids[mask]
            if has_original_eids:
                original_edge_ids = original_hetero_edge_ids
        return FusedSampledSubgraphImpl(
            node_pairs=node_pairs, original_edge_ids=original_edge_ids
        )
359

360
361
362
363
364
365
366
    def _convert_to_homogeneous_nodes(self, nodes):
        homogeneous_nodes = []
        for ntype, ids in nodes.items():
            ntype_id = self.metadata.node_type_to_id[ntype]
            homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
        return torch.cat(homogeneous_nodes)

367
    def sample_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
    ) -> FusedSampledSubgraphImpl:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph.

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: It means the graph is homogeneous
                graph, and ids inside are homogeneous ids.
              - If `nodes` is a dictionary: The keys should be node type and
                ids inside are heterogeneous ids.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal to the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute used.
            This attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        FusedSampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> metadata = gb.GraphMetadata(ntypes, etypes)
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
        ...     node_type_offset=node_type_offset, metadata=metadata)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_neighbors(nodes, fanouts)
        >>> print(subgraph.node_pairs)
        defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([0]),
          tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
        """
        # Per-type ids are shifted into one id space before sampling.
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)

        C_sampled_subgraph = self._sample_neighbors(
            nodes, fanouts, replace, probs_name
        )

        return self._convert_to_sampled_subgraph(C_sampled_subgraph)
445

446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
    def _check_sampler_arguments(self, nodes, fanouts, probs_name):
        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
        assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
        expected_fanout_len = 1
        if self.metadata and self.metadata.edge_type_to_id:
            expected_fanout_len = len(self.metadata.edge_type_to_id)
        assert len(fanouts) in [
            expected_fanout_len,
            1,
        ], "Fanouts should have the same number of elements as etypes or \
            should have a length of 1."
        if fanouts.size(0) > 1:
            assert (
                self.type_per_edge is not None
            ), "To perform sampling for each edge type (when the length of \
                `fanouts` > 1), the graph must include edge type information."
        assert torch.all(
            (fanouts >= 0) | (fanouts == -1)
        ), "Fanouts should consist of values that are either -1 or \
            greater than or equal to 0."
        if probs_name:
            assert (
                probs_name in self.edge_attributes
            ), f"Unknown edge attribute '{probs_name}'."
            probs_or_mask = self.edge_attributes[probs_name]
            assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
            assert (
473
                probs_or_mask.size(0) == self.total_num_edges
474
475
476
477
478
479
480
481
482
483
            ), "Probs should have the same number of elements as the number \
                of edges."
            assert probs_or_mask.dtype in [
                torch.bool,
                torch.float16,
                torch.bfloat16,
                torch.float32,
                torch.float64,
            ], "Probs should have a floating-point or boolean data type."

484
    def _sample_neighbors(
        self,
        nodes: torch.Tensor,
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
    ) -> torch.ScriptObject:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph.

        Parameters
        ----------
        nodes: torch.Tensor
            IDs of the given seed nodes.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal to the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        torch.classes.graphbolt.SampledSubgraph
            The sampled C subgraph.
        """
        # Validate nodes, fanouts and probs_name before sampling.
        self._check_sampler_arguments(nodes, fanouts, probs_name)
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        # The fourth positional argument is False here; the layer-neighbor
        # variant of this call passes True in the same position.
        return self._c_csc_graph.sample_neighbors(
            nodes,
            fanouts.tolist(),
            replace,
            False,
            has_original_eids,
            probs_name,
        )

    def sample_layer_neighbors(
        self,
        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
        fanouts: torch.Tensor,
        replace: bool = False,
        probs_name: Optional[str] = None,
    ) -> FusedSampledSubgraphImpl:
        """Sample neighboring edges of the given nodes and return the induced
        subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
        `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
        <https://arxiv.org/abs/2210.13339>`__

        Parameters
        ----------
        nodes: torch.Tensor or Dict[str, torch.Tensor]
            IDs of the given seed nodes.
              - If `nodes` is a tensor: It means the graph is homogeneous
                graph, and ids inside are homogeneous ids.
              - If `nodes` is a dictionary: The keys should be node type and
                ids inside are heterogeneous ids.
        fanouts: torch.Tensor
            The number of edges to be sampled for each node with or without
            considering edge types.
              - When the length is 1, it indicates that the fanout applies to
                all neighbors of the node as a collective, regardless of the
                edge type.
              - Otherwise, the length should equal to the number of edge
                types, and each fanout value corresponds to a specific edge
                type of the nodes.
            The value of each fanout should be >= 0 or = -1.
              - When the value is -1, all neighbors (with non-zero probability,
                if weighted) will be sampled once regardless of replacement. It
                is equivalent to selecting all neighbors with non-zero
                probability when the fanout is >= the number of neighbors (and
                replace is set to false).
              - When the value is a non-negative integer, it serves as a
                minimum threshold for selecting neighbors.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
        probs_name: str, optional
            An optional string specifying the name of an edge attribute. This
            attribute tensor should contain (unnormalized) probabilities
            corresponding to each neighboring edge of a node. It must be a 1D
            floating-point or boolean tensor, with the number of elements
            equalling the total number of edges.

        Returns
        -------
        FusedSampledSubgraphImpl
            The sampled subgraph.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> ntypes = {"n1": 0, "n2": 1}
        >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
        >>> metadata = gb.GraphMetadata(ntypes, etypes)
        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
        >>> node_type_offset = torch.LongTensor([0, 2, 5])
        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
        >>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
        ...     node_type_offset=node_type_offset, metadata=metadata)
        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
        >>> fanouts = torch.tensor([1, 1])
        >>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
        >>> print(subgraph.node_pairs)
        defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([1]),
          tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
        """
        # Per-type ids are shifted into one id space before sampling.
        if isinstance(nodes, dict):
            nodes = self._convert_to_homogeneous_nodes(nodes)

        self._check_sampler_arguments(nodes, fanouts, probs_name)
        has_original_eids = (
            self.edge_attributes is not None
            and ORIGINAL_EDGE_ID in self.edge_attributes
        )
        # The fourth positional argument is True here; `_sample_neighbors`
        # makes the same call with False in that position.
        C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
            nodes,
            fanouts.tolist(),
            replace,
            True,
            has_original_eids,
            probs_name,
        )

        return self._convert_to_sampled_subgraph(C_sampled_subgraph)
635

636
637
638
639
640
641
642
643
644
645
646
647
    def sample_negative_edges_uniform(
        self, edge_type, node_pairs, negative_ratio
    ):
        """
        Sample negative edges by randomly choosing negative source-destination
        pairs according to a uniform distribution. For each edge ``(u, v)``,
        it is supposed to generate `negative_ratio` pairs of negative edges
        ``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in
        the graph.

        Parameters
        ----------
        edge_type: str
            The type of edges in the provided node_pairs. Any negative edges
            sampled will also have the same type. If set to None, it will be
            considered as a homogeneous graph.
        node_pairs : Tuple[Tensor, Tensor]
            A tuple of two 1D tensors that represent the source and destination
            of positive edges, with 'positive' indicating that these edges are
            present in the graph. It's important to note that within the
            context of a heterogeneous graph, the ids in these tensors signify
            heterogeneous ids.
        negative_ratio: int
            The ratio of the number of negative samples to positive samples.

        Returns
        -------
        Tuple[Tensor, Tensor]
            A tuple consisting of two 1D tensors represents the source and
            destination of negative edges. In the context of a heterogeneous
            graph, both the input nodes and the selected nodes are represented
            by heterogeneous IDs, and the formed edges are of the input type
            `edge_type`. Note that negative refers to false negatives, which
            means the edge could be present or not present in the graph.

        Raises
        ------
        AssertionError
            If `edge_type` is given but the graph has no `node_type_offset`.
        """
        if edge_type is not None:
            # Read the property once; it is indexed twice below.
            node_type_offset = self.node_type_offset
            assert node_type_offset is not None, (
                "The 'node_type_offset' array is necessary for performing "
                "negative sampling by edge type."
            )
            _, _, dst_node_type = etype_str_to_tuple(edge_type)
            dst_node_type_id = self.metadata.node_type_to_id[dst_node_type]
            # Upper bound is the number of nodes of the destination type.
            max_node_id = (
                node_type_offset[dst_node_type_id + 1]
                - node_type_offset[dst_node_type_id]
            )
        else:
            max_node_id = self.total_num_nodes
        return self._c_csc_graph.sample_negative_edges_uniform(
            node_pairs,
            negative_ratio,
            max_node_id,
        )

690
691
692
693
694
695
696
697
698
699
    def copy_to_shared_memory(self, shared_memory_name: str):
        """Copy the graph to shared memory.

        Parameters
        ----------
        shared_memory_name : str
            Name of the shared memory.

        Returns
        -------
        FusedCSCSamplingGraph
            The copied FusedCSCSamplingGraph object on shared memory. It
            is built with this graph's metadata object.
        """
        return FusedCSCSamplingGraph(
            self._c_csc_graph.copy_to_shared_memory(shared_memory_name),
            self._metadata,
        )

708
    def to(  # pylint: disable=invalid-name
        self, device: torch.device
    ) -> "FusedCSCSamplingGraph":
        """Move the tensors of this `FusedCSCSamplingGraph` to the specified
        device.

        The graph is updated in place: `csc_indptr`, `indices`,
        `node_type_offset`, `type_per_edge` and `edge_attributes` are each
        reassigned with the result of applying `.to(device)` through
        `recursive_apply`.

        Parameters
        ----------
        device : torch.device
            The target device.

        Returns
        -------
        FusedCSCSamplingGraph
            This same graph object.
        """

        def _to(x):
            # Values without a `.to` method are passed through unchanged.
            return x.to(device) if hasattr(x, "to") else x

        self.csc_indptr = recursive_apply(self.csc_indptr, _to)
        self.indices = recursive_apply(self.indices, _to)
        self.node_type_offset = recursive_apply(self.node_type_offset, _to)
        self.type_per_edge = recursive_apply(self.type_per_edge, _to)
        self.edge_attributes = recursive_apply(self.edge_attributes, _to)

        return self

730

731
def from_fused_csc(
    csc_indptr: torch.Tensor,
    indices: torch.Tensor,
    node_type_offset: Optional[torch.Tensor] = None,
    type_per_edge: Optional[torch.Tensor] = None,
    edge_attributes: Optional[Dict[str, torch.Tensor]] = None,
    metadata: Optional[GraphMetadata] = None,
) -> FusedCSCSamplingGraph:
    """Create a FusedCSCSamplingGraph object from a CSC representation.

    Parameters
    ----------
    csc_indptr : torch.Tensor
        Pointer to the start of each row in the `indices`. An integer tensor
        with shape `(total_num_nodes+1,)`.
    indices : torch.Tensor
        Column indices of the non-zero elements in the CSC graph. An integer
        tensor with shape `(total_num_edges,)`.
    node_type_offset : Optional[torch.Tensor], optional
        Offset of node types in the graph, by default None.
    type_per_edge : Optional[torch.Tensor], optional
        Type ids of each edge in the graph, by default None.
    edge_attributes: Optional[Dict[str, torch.Tensor]], optional
        Edge attributes of the graph, by default None.
    metadata: Optional[GraphMetadata], optional
        Metadata of the graph, by default None.

    Returns
    -------
    FusedCSCSamplingGraph
        The created FusedCSCSamplingGraph object.

    Raises
    ------
    AssertionError
        If `metadata` defines node types, `node_type_offset` is given, and
        the length of `node_type_offset` is not the number of node types
        plus one.

    Examples
    --------
    >>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
    >>> etypes = {('n1', 'e1', 'n2'): 0, ('n1', 'e2', 'n3'): 1}
    >>> metadata = graphbolt.GraphMetadata(ntypes, etypes)
    >>> csc_indptr = torch.tensor([0, 2, 5, 7])
    >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
    >>> node_type_offset = torch.tensor([0, 1, 2, 3])
    >>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
    >>> graph = graphbolt.from_fused_csc(csc_indptr, indices,
    ...             node_type_offset=node_type_offset,
    ...             type_per_edge=type_per_edge,
    ...             edge_attributes=None, metadata=metadata)
    >>> print(graph)
    FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
                     indices=tensor([1, 3, 0, 1, 2, 0, 3]),
                     total_num_nodes=3, total_num_edges=7)
    """
    # The offsets bracket every node type, hence one more entry than there
    # are node types. Only checked when both pieces of information exist.
    if metadata and metadata.node_type_to_id and node_type_offset is not None:
        assert len(metadata.node_type_to_id) + 1 == node_type_offset.size(
            0
        ), "node_type_offset length should be |ntypes| + 1."
    return FusedCSCSamplingGraph(
        torch.ops.graphbolt.from_fused_csc(
            csc_indptr,
            indices,
            node_type_offset,
            type_per_edge,
            edge_attributes,
        ),
        metadata,
    )


797
798
799
def load_from_shared_memory(
    shared_memory_name: str,
    metadata: Optional[GraphMetadata] = None,
) -> FusedCSCSamplingGraph:
    """Load a FusedCSCSamplingGraph object from shared memory.

    Parameters
    ----------
    shared_memory_name : str
        Name of the shared memory.
    metadata : Optional[GraphMetadata], optional
        Metadata to attach to the loaded graph, by default None.

    Returns
    -------
    FusedCSCSamplingGraph
        The loaded FusedCSCSamplingGraph object on shared memory.
    """
    shared_c_graph = torch.ops.graphbolt.load_from_shared_memory(
        shared_memory_name
    )
    return FusedCSCSamplingGraph(shared_c_graph, metadata)


819
def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
    """Build the printable representation of a csc sampling graph.

    The result has the form ``ClassName(csc_indptr=..., indices=...,
    total_num_nodes=..., total_num_edges=...)`` with one field group per
    line; continuation lines are aligned under the first character that
    follows the opening parenthesis.
    """

    def _hang(text, width):
        # Leave the first line alone and push every following line right by
        # `width` spaces so multi-line values stay aligned.
        first, *rest = text.split("\n")
        return "\n".join([first, *(" " * width + row for row in rest)])

    indptr_label = "csc_indptr="
    indices_label = "indices="

    indptr_repr = str(graph.csc_indptr)
    indices_repr = str(graph.indices)
    summary = (
        f"total_num_nodes={graph.total_num_nodes}, total_num_edges="
        f"{graph.total_num_edges}"
    )
    opening = f"{type(graph).__name__}("

    body = ",\n".join(
        [
            indptr_label + _hang(indptr_repr, len(indptr_label)),
            indices_label + _hang(indices_repr, len(indices_label)),
            summary + ")",
        ]
    )
    return opening + _hang(body, len(opening))
849
850


851
852
def load_fused_csc_sampling_graph(filename):
    """Load FusedCSCSamplingGraph from tar file.

    The archive is unpacked into a temporary directory and is expected to
    contain `fused_csc_sampling_graph.pt` (the graph) and `metadata.pt`
    (the metadata object written with `torch.save`).

    Parameters
    ----------
    filename : str
        Path of the tar file to load.

    Returns
    -------
    FusedCSCSamplingGraph
        The loaded graph together with its stored metadata.

    Raises
    ------
    ValueError
        If the archive contains a link member or a member whose path would
        be extracted outside of the temporary directory.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        extract_root = os.path.realpath(temp_dir)
        with tarfile.open(filename, "r") as archive:
            # Validate every member before anything is written to disk so a
            # crafted archive cannot place files outside of `temp_dir`.
            for member in archive.getmembers():
                if member.issym() or member.islnk():
                    raise ValueError(
                        f"Refusing to extract link member '{member.name}' "
                        f"from {filename}."
                    )
                target = os.path.realpath(
                    os.path.join(extract_root, member.name)
                )
                if os.path.commonpath([extract_root, target]) != extract_root:
                    raise ValueError(
                        f"Refusing to extract member '{member.name}' from "
                        f"{filename}: path escapes the extraction directory."
                    )
            archive.extractall(temp_dir)
        graph_filename = os.path.join(temp_dir, "fused_csc_sampling_graph.pt")
        metadata_filename = os.path.join(temp_dir, "metadata.pt")
        # NOTE(review): `torch.load` unpickles the metadata object, so only
        # archives from trusted sources should be loaded.
        return FusedCSCSamplingGraph(
            torch.ops.graphbolt.load_fused_csc_sampling_graph(graph_filename),
            torch.load(metadata_filename),
        )


864
865
def save_fused_csc_sampling_graph(graph, filename):
    """Save FusedCSCSamplingGraph to tar file.

    The graph and its metadata are first written as two separate files in a
    temporary directory and then packed, under their base names, into the
    tar archive at `filename`. A confirmation message is printed at the end.

    Parameters
    ----------
    graph : FusedCSCSamplingGraph
        The graph to save.
    filename : str
        Path of the tar file to create.
    """
    with tempfile.TemporaryDirectory() as staging_dir:
        graph_path = os.path.join(staging_dir, "fused_csc_sampling_graph.pt")
        metadata_path = os.path.join(staging_dir, "metadata.pt")

        # Stage both parts on disk before building the archive.
        torch.ops.graphbolt.save_fused_csc_sampling_graph(
            graph._c_csc_graph, graph_path
        )
        torch.save(graph.metadata, metadata_path)

        with tarfile.open(filename, "w") as archive:
            for staged_path in (graph_path, metadata_path):
                # Store only the base name so the archive has a flat layout.
                archive.add(
                    staged_path, arcname=os.path.basename(staged_path)
                )
    print(f"FusedCSCSamplingGraph has been saved to {filename}.")
881
882


883
884
885
886
def from_dglgraph(
    g: DGLGraph,
    is_homogeneous: bool = False,
    include_original_edge_id: bool = False,
) -> FusedCSCSamplingGraph:
    """Convert a DGLGraph to FusedCSCSamplingGraph.

    Parameters
    ----------
    g : DGLGraph
        The graph to convert.
    is_homogeneous : bool, optional
        If True, no metadata and no per-edge type tensor are produced, by
        default False.
    include_original_edge_id : bool, optional
        If True, the edge IDs stored under `EID` in the homogenized graph
        are kept as the `ORIGINAL_EDGE_ID` edge attribute, by default False.

    Returns
    -------
    FusedCSCSamplingGraph
        The converted graph.
    """
    homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)

    # Type name -> type ID mappings are only built for heterogeneous input.
    metadata = None
    if not is_homogeneous:
        ntype_to_id = {}
        for ntype in g.ntypes:
            ntype_to_id[ntype] = g.get_ntype_id(ntype)
        etype_to_id = {}
        for canonical_etype in g.canonical_etypes:
            etype_key = etype_tuple_to_str(canonical_etype)
            etype_to_id[etype_key] = g.get_etype_id(canonical_etype)
        metadata = GraphMetadata(ntype_to_id, etype_to_id)

    # CSC tensors of the homogenized graph; `csc_edge_ids` is used below to
    # put per-edge data into the order of the CSC matrix.
    csc_indptr, csc_indices, csc_edge_ids = homo_g.adj_tensors("csc")

    # Prefix-sum the per-type node counts (with a leading zero) to obtain
    # the node type offsets.
    ntype_count.insert(0, 0)
    counts = torch.LongTensor(ntype_count)
    node_type_offset = torch.cumsum(counts, 0)

    if is_homogeneous:
        type_per_edge = None
    else:
        type_per_edge = homo_g.edata[ETYPE][csc_edge_ids]

    edge_attributes = (
        {ORIGINAL_EDGE_ID: homo_g.edata[EID][csc_edge_ids]}
        if include_original_edge_id
        else {}
    )

    c_graph = torch.ops.graphbolt.from_fused_csc(
        csc_indptr,
        csc_indices,
        node_type_offset,
        type_per_edge,
        edge_attributes,
    )
    return FusedCSCSamplingGraph(c_graph, metadata)