"""A set of graph services for getting subgraphs from a DistGraph."""
from collections import namedtuple

import numpy as np

from .. import backend as F
from ..base import EID, NID
from ..convert import graph, heterograph
from ..sampling import sample_etype_neighbors as local_sample_etype_neighbors
from ..sampling import sample_neighbors as local_sample_neighbors
from ..subgraph import in_subgraph as local_in_subgraph
from ..utils import toindex
from .rpc import (
    Request,
    Response,
    recv_responses,
    register_service,
    send_requests_to_machine,
)

__all__ = [
    "sample_neighbors",
    "sample_etype_neighbors",
    "in_subgraph",
    "find_edges",
]

SAMPLING_SERVICE_ID = 6657
INSUBGRAPH_SERVICE_ID = 6658
EDGES_SERVICE_ID = 6659
OUTDEGREE_SERVICE_ID = 6660
INDEGREE_SERVICE_ID = 6661
ETYPE_SAMPLING_SERVICE_ID = 6662


class SubgraphResponse(Response):
    """The response for sampling and in_subgraph"""

    def __init__(self, global_src, global_dst, global_eids):
        self.global_src = global_src
        self.global_dst = global_dst
        self.global_eids = global_eids

    def __setstate__(self, state):
        self.global_src, self.global_dst, self.global_eids = state

    def __getstate__(self):
        return self.global_src, self.global_dst, self.global_eids
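
# ``__getstate__`` / ``__setstate__`` above return and restore the plain tuple
# that the RPC layer serializes when a response is shipped between machines;
# every Request/Response subclass in this module follows the same pattern.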


class FindEdgeResponse(Response):
    """The response for find_edges"""

    def __init__(self, global_src, global_dst, order_id):
        self.global_src = global_src
        self.global_dst = global_dst
        self.order_id = order_id

    def __setstate__(self, state):
        self.global_src, self.global_dst, self.order_id = state

    def __getstate__(self):
        return self.global_src, self.global_dst, self.order_id

def _sample_neighbors(
    local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace
):
    """Sample from local partition.

    The input nodes use global IDs. We need to map the global node IDs to local
    node IDs, perform sampling, and map the sampled results back to the global
    ID space.
    The sampled results are stored in three vectors holding the source nodes,
    destination nodes and edge IDs.
    """
    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
    local_ids = F.astype(local_ids, local_g.idtype)
    # local_ids = self.seed_nodes
    sampled_graph = local_sample_neighbors(
        local_g,
        local_ids,
        fan_out,
        edge_dir,
        prob,
        replace,
        _dist_training=True,
    )
    global_nid_mapping = local_g.ndata[NID]
    src, dst = sampled_graph.edges()
    global_src, global_dst = F.gather_row(
        global_nid_mapping, src
    ), F.gather_row(global_nid_mapping, dst)
    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
    return global_src, global_dst, global_eids


def _sample_etype_neighbors(
    local_g,
    partition_book,
    seed_nodes,
    etype_field,
    fan_out,
    edge_dir,
    prob,
    replace,
    etype_sorted=False,
):
    """Sample from local partition.

    The input nodes use global IDs. We need to map the global node IDs to local
    node IDs, perform sampling, and map the sampled results back to the global
    ID space.
    The sampled results are stored in three vectors holding the source nodes,
    destination nodes and edge IDs.
    """
    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
    local_ids = F.astype(local_ids, local_g.idtype)
    sampled_graph = local_sample_etype_neighbors(
        local_g,
        local_ids,
        etype_field,
        fan_out,
        edge_dir,
        prob,
        replace,
        etype_sorted=etype_sorted,
        _dist_training=True,
    )
    global_nid_mapping = local_g.ndata[NID]
    src, dst = sampled_graph.edges()
    global_src, global_dst = F.gather_row(
        global_nid_mapping, src
    ), F.gather_row(global_nid_mapping, dst)
    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
    return global_src, global_dst, global_eids

def _find_edges(local_g, partition_book, seed_edges):
    """Given an edge ID array, return the source
    and destination node ID arrays ``s`` and ``d`` in the local partition.
    """
    local_eids = partition_book.eid2localeid(seed_edges, partition_book.partid)
    local_eids = F.astype(local_eids, local_g.idtype)
    local_src, local_dst = local_g.find_edges(local_eids)
    global_nid_mapping = local_g.ndata[NID]
    global_src = global_nid_mapping[local_src]
    global_dst = global_nid_mapping[local_dst]
    return global_src, global_dst


def _in_degrees(local_g, partition_book, n):
    """Get in-degree of the nodes in the local partition."""
    local_nids = partition_book.nid2localnid(n, partition_book.partid)
    local_nids = F.astype(local_nids, local_g.idtype)
    return local_g.in_degrees(local_nids)


def _out_degrees(local_g, partition_book, n):
    """Get out-degree of the nodes in the local partition."""
    local_nids = partition_book.nid2localnid(n, partition_book.partid)
    local_nids = F.astype(local_nids, local_g.idtype)
    return local_g.out_degrees(local_nids)


def _in_subgraph(local_g, partition_book, seed_nodes):
    """Get the in-subgraph of the local partition.

    The input nodes use global IDs. We need to map the global node IDs to local
    node IDs, extract the in-subgraph, and map the results back to the global
    ID space.
    The results are stored in three vectors holding the source nodes,
    destination nodes and edge IDs.
    """
    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
    local_ids = F.astype(local_ids, local_g.idtype)
    # local_ids = self.seed_nodes
    sampled_graph = local_in_subgraph(local_g, local_ids)
    global_nid_mapping = local_g.ndata[NID]
    src, dst = sampled_graph.edges()
    global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]
    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
    return global_src, global_dst, global_eids


class SamplingRequest(Request):
    """Sampling Request"""

    def __init__(self, nodes, fan_out, edge_dir="in", prob=None, replace=False):
        self.seed_nodes = nodes
        self.edge_dir = edge_dir
        self.prob = prob
        self.replace = replace
        self.fan_out = fan_out

    def __setstate__(self, state):
        (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
        ) = state

    def __getstate__(self):
        return (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
        )

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        global_src, global_dst, global_eids = _sample_neighbors(
            local_g,
            partition_book,
            self.seed_nodes,
            self.fan_out,
            self.edge_dir,
            self.prob,
            self.replace,
        )
        return SubgraphResponse(global_src, global_dst, global_eids)

class SamplingRequestEtype(Request):
    """Sampling Request"""

    def __init__(
        self,
        nodes,
        etype_field,
        fan_out,
        edge_dir="in",
        prob=None,
        replace=False,
        etype_sorted=True,
    ):
        self.seed_nodes = nodes
        self.edge_dir = edge_dir
        self.prob = prob
        self.replace = replace
        self.fan_out = fan_out
        self.etype_field = etype_field
        self.etype_sorted = etype_sorted

    def __setstate__(self, state):
        (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
            self.etype_field,
            self.etype_sorted,
        ) = state

    def __getstate__(self):
        return (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
            self.etype_field,
            self.etype_sorted,
        )

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        global_src, global_dst, global_eids = _sample_etype_neighbors(
            local_g,
            partition_book,
            self.seed_nodes,
            self.etype_field,
            self.fan_out,
            self.edge_dir,
            self.prob,
            self.replace,
            self.etype_sorted,
        )
        return SubgraphResponse(global_src, global_dst, global_eids)

class EdgesRequest(Request):
    """Edges Request"""

    def __init__(self, edge_ids, order_id):
        self.edge_ids = edge_ids
        self.order_id = order_id

    def __setstate__(self, state):
        self.edge_ids, self.order_id = state

    def __getstate__(self):
        return self.edge_ids, self.order_id

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        global_src, global_dst = _find_edges(
            local_g, partition_book, self.edge_ids
        )

        return FindEdgeResponse(global_src, global_dst, self.order_id)

class InDegreeRequest(Request):
    """In-degree Request"""

    def __init__(self, n, order_id):
        self.n = n
        self.order_id = order_id

    def __setstate__(self, state):
        self.n, self.order_id = state

    def __getstate__(self):
        return self.n, self.order_id

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        deg = _in_degrees(local_g, partition_book, self.n)

        return InDegreeResponse(deg, self.order_id)

class InDegreeResponse(Response):
    """The response for in-degree"""

    def __init__(self, deg, order_id):
        self.val = deg
        self.order_id = order_id

    def __setstate__(self, state):
        self.val, self.order_id = state

    def __getstate__(self):
        return self.val, self.order_id

class OutDegreeRequest(Request):
    """Out-degree Request"""

    def __init__(self, n, order_id):
        self.n = n
        self.order_id = order_id

    def __setstate__(self, state):
        self.n, self.order_id = state

    def __getstate__(self):
        return self.n, self.order_id

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        deg = _out_degrees(local_g, partition_book, self.n)

        return OutDegreeResponse(deg, self.order_id)

class OutDegreeResponse(Response):
    """The response for out-degree"""

    def __init__(self, deg, order_id):
        self.val = deg
        self.order_id = order_id

    def __setstate__(self, state):
        self.val, self.order_id = state

    def __getstate__(self):
        return self.val, self.order_id

class InSubgraphRequest(Request):
    """InSubgraph Request"""

    def __init__(self, nodes):
        self.seed_nodes = nodes

    def __setstate__(self, state):
        self.seed_nodes = state

    def __getstate__(self):
        return self.seed_nodes

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        global_src, global_dst, global_eids = _in_subgraph(
            local_g, partition_book, self.seed_nodes
        )
        return SubgraphResponse(global_src, global_dst, global_eids)


def merge_graphs(res_list, num_nodes):
    """Merge request from multiple servers"""
    if len(res_list) > 1:
        srcs = []
        dsts = []
        eids = []
        for res in res_list:
            srcs.append(res.global_src)
            dsts.append(res.global_dst)
            eids.append(res.global_eids)
        src_tensor = F.cat(srcs, 0)
        dst_tensor = F.cat(dsts, 0)
        eid_tensor = F.cat(eids, 0)
    else:
        src_tensor = res_list[0].global_src
        dst_tensor = res_list[0].global_dst
        eid_tensor = res_list[0].global_eids
    g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
    g.edata[EID] = eid_tensor
    return g

LocalSampledGraph = namedtuple(
    "LocalSampledGraph", "global_src global_dst global_eids"
)

def _distributed_access(g, nodes, issue_remote_req, local_access):
    """A routine that fetches the local neighborhoods of nodes from the distributed graph.

    The neighborhoods of some nodes are stored on the local machine while other
    nodes have their neighborhoods on remote machines. This code issues the
    remote access requests before fetching data from the local machine, and
    combines the data from the local machine and the remote machines at the end.
    In this way, we can hide the latency of accessing data on remote machines.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    nodes : tensor
        The nodes whose neighborhoods are to be fetched.
    issue_remote_req : callable
        The function that issues requests to access remote data.
    local_access : callable
        The function that reads data on the local machine.

    Returns
    -------
    DGLHeteroGraph
        The subgraph that contains the neighborhoods of all input nodes.
    """
    req_list = []
    partition_book = g.get_partition_book()
    nodes = toindex(nodes).tousertensor()
    partition_id = partition_book.nid2partid(nodes)
    local_nids = None
    for pid in range(partition_book.num_partitions()):
        node_id = F.boolean_mask(nodes, partition_id == pid)
        # We optimize the sampling on a local partition if the server and the client
        # run on the same machine. With a good partitioning, most of the seed nodes
        # should reside in the local partition. If the server and the client
        # are not co-located, the client doesn't have a local partition.
        if pid == partition_book.partid and g.local_partition is not None:
            assert local_nids is None
            local_nids = node_id
        elif len(node_id) != 0:
            req = issue_remote_req(node_id)
            req_list.append((pid, req))

    # send requests to the remote machine.
    msgseq2pos = None
    if len(req_list) > 0:
        msgseq2pos = send_requests_to_machine(req_list)

    # sample neighbors for the nodes in the local partition.
    res_list = []
    if local_nids is not None:
        src, dst, eids = local_access(
            g.local_partition, partition_book, local_nids
        )
        res_list.append(LocalSampledGraph(src, dst, eids))

    # receive responses from remote machines.
    if msgseq2pos is not None:
        results = recv_responses(msgseq2pos)
        res_list.extend(results)

    sampled_graph = merge_graphs(res_list, g.number_of_nodes())
    return sampled_graph
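
# How callers plug into ``_distributed_access``: a caller provides two
# callbacks -- ``issue_remote_req(node_ids)`` builds a Request for nodes owned
# by a remote partition, and ``local_access(local_g, partition_book, nids)``
# computes the same result directly on the local partition. The public
# ``sample_neighbors`` and ``in_subgraph`` below are concrete instances of
# this pattern.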

def _frontier_to_heterogeneous_graph(g, frontier, gpb):
    # We need to handle empty frontiers correctly.
    if frontier.number_of_edges() == 0:
        data_dict = {
            etype: (np.zeros(0), np.zeros(0)) for etype in g.canonical_etypes
        }
        return heterograph(
            data_dict,
            {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
            idtype=g.idtype,
        )
    etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
    src, dst = frontier.edges()
    etype_ids, idx = F.sort_1d(etype_ids)
    src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
    eid = F.gather_row(frontier.edata[EID], idx)
    _, src = gpb.map_to_per_ntype(src)
    _, dst = gpb.map_to_per_ntype(dst)

    data_dict = dict()
    edge_ids = {}
    for etid in range(len(g.etypes)):
        etype = g.etypes[etid]
        canonical_etype = g.canonical_etypes[etid]
        type_idx = etype_ids == etid
        if F.sum(type_idx, 0) > 0:
            data_dict[canonical_etype] = (
                F.boolean_mask(src, type_idx),
                F.boolean_mask(dst, type_idx),
            )
            edge_ids[etype] = F.boolean_mask(eid, type_idx)
    hg = heterograph(
        data_dict,
        {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
        idtype=g.idtype,
    )

    for etype in edge_ids:
        hg.edges[etype].data[EID] = edge_ids[etype]
    return hg


def sample_etype_neighbors(
    g,
    nodes,
    etype_field,
    fanout,
    edge_dir="in",
    prob=None,
    replace=False,
    etype_sorted=True,
):
    """Sample from the neighbors of the given nodes from a distributed graph.

    For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
    will be randomly chosen.  The returned graph will contain all the nodes in the
    original graph, but only the sampled edges.

    Node/edge features are not preserved. The original IDs of
    the sampled edges are stored as the `dgl.EID` feature in the returned graph.

    This function assumes the input is a homogeneous ``DGLGraph`` with the TRUE edge type
    information stored as the edge data in `etype_field`. The sampled subgraph is also
    stored in the homogeneous graph format. That is, all nodes and edges are assigned
    with unique IDs (in contrast, we typically use a type name and a node/edge ID to
    identify a node or an edge in ``DGLGraph``). We refer to this type of IDs
    as *homogeneous ID*.
    Users can use :func:`dgl.distributed.GraphPartitionBook.map_to_per_ntype`
    and :func:`dgl.distributed.GraphPartitionBook.map_to_per_etype`
    to identify their node/edge types and node/edge IDs of that type.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    nodes : tensor or dict
        Node IDs to sample neighbors from. If it's a dict, it should contain only
        one key-value pair to make this API consistent with dgl.sampling.sample_neighbors.
    etype_field : string
        The field in g.edata storing the edge type.
    fanout : int or dict[etype, int]
        The number of edges to be sampled for each node per edge type.  If an integer
        is given, DGL assumes that the same fanout is applied to every edge type.

        If -1 is given, all of the neighbors will be selected.
    edge_dir : str, optional
        Determines whether to sample inbound or outbound edges.

        Can take either ``in`` for inbound edges or ``out`` for outbound edges.
    prob : str, optional
        Feature name used as the (unnormalized) probabilities associated with each
        neighboring edge of a node.  The feature must have only one element for each
        edge.

        The features must be non-negative floats, and the sum of the features of
        inbound/outbound edges for every node must be positive (though they don't have
        to sum up to one).  Otherwise, the result will be undefined.
    replace : bool, optional
        If True, sample with replacement.

        When sampling with replacement, the sampled subgraph could have parallel edges.

        For sampling without replacement, if fanout > the number of neighbors, all the
        neighbors are sampled. If fanout == -1, all neighbors are collected.
    etype_sorted : bool, optional
        A hint indicating whether the edge types are sorted.

    Returns
    -------
    DGLGraph
        A sampled subgraph containing only the sampled neighboring edges.  It is on CPU.
    """
    if isinstance(fanout, int):
        fanout = F.full_1d(len(g.etypes), fanout, F.int64, F.cpu())
    else:
        fanout = F.tensor([fanout[etype] for etype in g.etypes], dtype=F.int64)

    gpb = g.get_partition_book()
    if isinstance(nodes, dict):
        homo_nids = []
        for ntype in nodes.keys():
            assert (
                ntype in g.ntypes
            ), "The sampled node type {} does not exist in the input graph".format(
                ntype
            )
            if F.is_tensor(nodes[ntype]):
                typed_nodes = nodes[ntype]
            else:
                typed_nodes = toindex(nodes[ntype]).tousertensor()
            homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
        nodes = F.cat(homo_nids, 0)
    def issue_remote_req(node_ids):
        return SamplingRequestEtype(
            node_ids,
            etype_field,
            fanout,
            edge_dir=edge_dir,
            prob=prob,
            replace=replace,
            etype_sorted=etype_sorted,
        )

    def local_access(local_g, partition_book, local_nids):
        return _sample_etype_neighbors(
            local_g,
            partition_book,
            local_nids,
            etype_field,
            fanout,
            edge_dir,
            prob,
            replace,
            etype_sorted=etype_sorted,
        )

    frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
    if not gpb.is_homogeneous:
        return _frontier_to_heterogeneous_graph(g, frontier, gpb)
    else:
        return frontier
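

# Illustrative sketch (not part of the original API): a minimal way to drive
# ``sample_etype_neighbors`` from a trainer process. It assumes ``dist_g`` is a
# ``DistGraph`` backed by a partitioned heterogeneous graph whose edge types are
# stored in the homogeneous edge data under ``dgl.ETYPE``; the seed nodes and
# fanout are placeholders.
def _example_sample_etype_neighbors(dist_g, seed_nodes, fanout=10):
    """Hedged usage sketch for :func:`sample_etype_neighbors`."""
    from ..base import ETYPE  # conventional edge-type field (assumption)

    # Sample up to ``fanout`` inbound edges per seed node and per edge type.
    frontier = sample_etype_neighbors(dist_g, seed_nodes, ETYPE, fanout)
    # For heterogeneous partitions the result is a heterograph whose per-etype
    # edge IDs are stored under ``dgl.EID``.
    return frontier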


def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
    """Sample from the neighbors of the given nodes from a distributed graph.

    For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
    will be randomly chosen.  The returned graph will contain all the nodes in the
    original graph, but only the sampled edges.

    Node/edge features are not preserved. The original IDs of
    the sampled edges are stored as the `dgl.EID` feature in the returned graph.

    For heterogeneous graphs, ``nodes`` is a dictionary whose key is node type
    and the value is type-specific node IDs.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    nodes : tensor or dict
        Node IDs to sample neighbors from. If it's a dict, it should contain only
        one key-value pair to make this API consistent with dgl.sampling.sample_neighbors.
    fanout : int
        The number of edges to be sampled for each node.

        If -1 is given, all of the neighbors will be selected.
    edge_dir : str, optional
        Determines whether to sample inbound or outbound edges.

        Can take either ``in`` for inbound edges or ``out`` for outbound edges.
    prob : str, optional
        Feature name used as the (unnormalized) probabilities associated with each
        neighboring edge of a node.  The feature must have only one element for each
        edge.

        The features must be non-negative floats, and the sum of the features of
        inbound/outbound edges for every node must be positive (though they don't have
        to sum up to one).  Otherwise, the result will be undefined.
    replace : bool, optional
        If True, sample with replacement.

        When sampling with replacement, the sampled subgraph could have parallel edges.

        For sampling without replacement, if fanout > the number of neighbors, all the
        neighbors are sampled. If fanout == -1, all neighbors are collected.

    Returns
    -------
    DGLGraph
        A sampled subgraph containing only the sampled neighboring edges.  It is on CPU.
    """
    gpb = g.get_partition_book()
    if not gpb.is_homogeneous:
        assert isinstance(nodes, dict)
        homo_nids = []
        for ntype in nodes:
            assert (
                ntype in g.ntypes
            ), "The sampled node type does not exist in the input graph"
            if F.is_tensor(nodes[ntype]):
                typed_nodes = nodes[ntype]
            else:
                typed_nodes = toindex(nodes[ntype]).tousertensor()
            homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
        nodes = F.cat(homo_nids, 0)
    elif isinstance(nodes, dict):
        assert len(nodes) == 1
        nodes = list(nodes.values())[0]

    def issue_remote_req(node_ids):
        return SamplingRequest(
            node_ids, fanout, edge_dir=edge_dir, prob=prob, replace=replace
        )

    def local_access(local_g, partition_book, local_nids):
        return _sample_neighbors(
            local_g, partition_book, local_nids, fanout, edge_dir, prob, replace
        )

    frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
    if not gpb.is_homogeneous:
        return _frontier_to_heterogeneous_graph(g, frontier, gpb)
    else:
        return frontier
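

# Illustrative sketch (not part of the original API): typical trainer-side use
# of ``sample_neighbors``. ``dist_g`` is assumed to be an already-constructed
# ``DistGraph`` (i.e. ``dgl.distributed.initialize`` has been called); the seed
# nodes and fanout are placeholders.
def _example_sample_neighbors(dist_g, seed_nodes, fanout=10):
    """Hedged usage sketch for :func:`sample_neighbors`."""
    # Sample up to ``fanout`` inbound edges for every seed node.
    frontier = sample_neighbors(dist_g, seed_nodes, fanout, edge_dir="in")
    # For a graph with a single edge type, the global IDs of the sampled edges
    # are stored under ``dgl.EID``.
    return frontier, frontier.edata[EID]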

def _distributed_edge_access(g, edges, issue_remote_req, local_access):
    """A routine that fetches local edges from distributed graph.

    The source and destination nodes of local edges are stored in the local
    machine and others are stored on remote machines. This code will issue
    remote access requests first before fetching data from the local machine.
    In the end, we combine the data from the local machine and remote machines.

    Parameters
    ----------
    g : DistGraph
        The distributed graph
    edges : tensor
        The edges to find their source and destination nodes.
    issue_remote_req : callable
        The function that issues requests to access remote data.
    local_access : callable
        The function that reads data on the local machine.

    Returns
    -------
    tensor
        The source node ID array.
    tensor
        The destination node ID array.
    """
    req_list = []
    partition_book = g.get_partition_book()
    edges = toindex(edges).tousertensor()
    partition_id = partition_book.eid2partid(edges)
    local_eids = None
    reorder_idx = []
    for pid in range(partition_book.num_partitions()):
        mask = partition_id == pid
        edge_id = F.boolean_mask(edges, mask)
        reorder_idx.append(F.nonzero_1d(mask))
        if pid == partition_book.partid and g.local_partition is not None:
            assert local_eids is None
            local_eids = edge_id
        elif len(edge_id) != 0:
            req = issue_remote_req(edge_id, pid)
            req_list.append((pid, req))

    # send requests to the remote machine.
    msgseq2pos = None
    if len(req_list) > 0:
        msgseq2pos = send_requests_to_machine(req_list)

    # handle edges in local partition.
    src_ids = F.zeros_like(edges)
    dst_ids = F.zeros_like(edges)
    if local_eids is not None:
        src, dst = local_access(g.local_partition, partition_book, local_eids)
        src_ids = F.scatter_row(
            src_ids, reorder_idx[partition_book.partid], src
        )
        dst_ids = F.scatter_row(
            dst_ids, reorder_idx[partition_book.partid], dst
        )

    # receive responses from remote machines.
    if msgseq2pos is not None:
        results = recv_responses(msgseq2pos)
        for result in results:
            src = result.global_src
            dst = result.global_dst
            src_ids = F.scatter_row(src_ids, reorder_idx[result.order_id], src)
            dst_ids = F.scatter_row(dst_ids, reorder_idx[result.order_id], dst)
    return src_ids, dst_ids

def find_edges(g, edge_ids):
    """Given an edge ID array, return the source and destination
    node ID arrays ``s`` and ``d`` from a distributed graph.
    ``s[i]`` and ``d[i]`` are the source and destination node IDs of
    edge ``edge_ids[i]``.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    edge_ids : tensor
        The edge ID array.

    Returns
    -------
    tensor
        The source node ID array.
    tensor
        The destination node ID array.
    """
    def issue_remote_req(edge_ids, order_id):
        return EdgesRequest(edge_ids, order_id)
    def local_access(local_g, partition_book, edge_ids):
        return _find_edges(local_g, partition_book, edge_ids)
    return _distributed_edge_access(g, edge_ids, issue_remote_req, local_access)
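

# Illustrative sketch (not part of the original API): looking up edge endpoints
# with ``find_edges``. ``dist_g`` is assumed to be a ``DistGraph`` and
# ``edge_ids`` a tensor of global edge IDs; both names are placeholders.
def _example_find_edges(dist_g, edge_ids):
    """Hedged usage sketch for :func:`find_edges`."""
    # ``src[i]`` and ``dst[i]`` are the endpoints of edge ``edge_ids[i]``.
    src, dst = find_edges(dist_g, edge_ids)
    return src, dst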

def in_subgraph(g, nodes):
    """Return the subgraph induced on the inbound edges of the given nodes.

    The subgraph keeps the same type schema and all the nodes are preserved regardless
    of whether they have an edge or not.

    Node/edge features are not preserved. The original IDs of
    the extracted edges are stored as the `dgl.EID` feature in the returned graph.

    For now, we only support the input graph with one node type and one edge type.

    Parameters
    ----------
    g : DistGraph
        The distributed graph structure.
    nodes : tensor or dict
        The node IDs whose inbound edges are extracted.

    Returns
    -------
    DGLGraph
        The subgraph.

        One can retrieve the mapping from subgraph edge ID to parent
        edge ID via ``dgl.EID`` edge features of the subgraph.
    """
    if isinstance(nodes, dict):
        assert (
            len(nodes) == 1
        ), "The distributed in_subgraph only supports one node type for now."
        nodes = list(nodes.values())[0]

    def issue_remote_req(node_ids):
        return InSubgraphRequest(node_ids)

    def local_access(local_g, partition_book, local_nids):
        return _in_subgraph(local_g, partition_book, local_nids)

    return _distributed_access(g, nodes, issue_remote_req, local_access)
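

# Illustrative sketch (not part of the original API): extracting the in-subgraph
# of a set of seed nodes. ``dist_g`` and ``seed_nodes`` are placeholders; the
# graph is assumed to have a single node type and a single edge type, as
# required by the distributed ``in_subgraph``.
def _example_in_subgraph(dist_g, seed_nodes):
    """Hedged usage sketch for :func:`in_subgraph`."""
    sg = in_subgraph(dist_g, seed_nodes)
    # Map subgraph edges back to the parent graph through ``dgl.EID``.
    return sg, sg.edata[EID]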

def _distributed_get_node_property(g, n, issue_remote_req, local_access):
    req_list = []
    partition_book = g.get_partition_book()
    n = toindex(n).tousertensor()
    partition_id = partition_book.nid2partid(n)
    local_nids = None
    reorder_idx = []
    for pid in range(partition_book.num_partitions()):
        mask = partition_id == pid
        nid = F.boolean_mask(n, mask)
        reorder_idx.append(F.nonzero_1d(mask))
        if pid == partition_book.partid and g.local_partition is not None:
            assert local_nids is None
            local_nids = nid
        elif len(nid) != 0:
            req = issue_remote_req(nid, pid)
            req_list.append((pid, req))

    # send requests to the remote machine.
    msgseq2pos = None
    if len(req_list) > 0:
        msgseq2pos = send_requests_to_machine(req_list)

    # handle edges in local partition.
    vals = None
    if local_nids is not None:
        local_vals = local_access(g.local_partition, partition_book, local_nids)
        shape = list(F.shape(local_vals))
        shape[0] = len(n)
        vals = F.zeros(shape, F.dtype(local_vals), F.cpu())
        vals = F.scatter_row(
            vals, reorder_idx[partition_book.partid], local_vals
        )

    # receive responses from remote machines.
    if msgseq2pos is not None:
        results = recv_responses(msgseq2pos)
        if len(results) > 0 and vals is None:
            shape = list(F.shape(results[0].val))
            shape[0] = len(n)
            vals = F.zeros(shape, F.dtype(results[0].val), F.cpu())
        for result in results:
            val = result.val
            vals = F.scatter_row(vals, reorder_idx[result.order_id], val)
    return vals

def in_degrees(g, v):
    """Get in-degrees"""

    def issue_remote_req(v, order_id):
        return InDegreeRequest(v, order_id)
    def local_access(local_g, partition_book, v):
        return _in_degrees(local_g, partition_book, v)
    return _distributed_get_node_property(g, v, issue_remote_req, local_access)

def out_degrees(g, u):
    """Get out-degrees"""

    def issue_remote_req(u, order_id):
        return OutDegreeRequest(u, order_id)
    def local_access(local_g, partition_book, u):
        return _out_degrees(local_g, partition_book, u)
    return _distributed_get_node_property(g, u, issue_remote_req, local_access)
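

# Illustrative sketch (not part of the original API): querying degrees of a set
# of nodes from a ``DistGraph``. ``dist_g`` and ``node_ids`` are placeholders.
def _example_degrees(dist_g, node_ids):
    """Hedged usage sketch for :func:`in_degrees` / :func:`out_degrees`."""
    return in_degrees(dist_g, node_ids), out_degrees(dist_g, node_ids)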

register_service(SAMPLING_SERVICE_ID, SamplingRequest, SubgraphResponse)
register_service(EDGES_SERVICE_ID, EdgesRequest, FindEdgeResponse)
register_service(INSUBGRAPH_SERVICE_ID, InSubgraphRequest, SubgraphResponse)
register_service(OUTDEGREE_SERVICE_ID, OutDegreeRequest, OutDegreeResponse)
register_service(INDEGREE_SERVICE_ID, InDegreeRequest, InDegreeResponse)
register_service(
    ETYPE_SAMPLING_SERVICE_ID, SamplingRequestEtype, SubgraphResponse
)
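
# Each ``register_service`` call above binds a service ID to its Request and
# Response classes so the RPC layer can dispatch incoming messages to the right
# ``process_request`` handler on the server and decode the reply on the client.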