"""A set of graph services of getting subgraphs from DistGraph"""
2
from collections import namedtuple

import numpy as np

from .. import backend as F
from ..base import EID, NID
from ..convert import graph, heterograph
from ..sampling import (
    sample_etype_neighbors as local_sample_etype_neighbors,
    sample_neighbors as local_sample_neighbors,
)
from ..subgraph import in_subgraph as local_in_subgraph
from ..utils import toindex
from .rpc import (
    recv_responses,
    register_service,
    Request,
    Response,
    send_requests_to_machine,
)

__all__ = [
    "sample_neighbors",
    "sample_etype_neighbors",
    "in_subgraph",
    "find_edges",
]

SAMPLING_SERVICE_ID = 6657
INSUBGRAPH_SERVICE_ID = 6658
EDGES_SERVICE_ID = 6659
OUTDEGREE_SERVICE_ID = 6660
INDEGREE_SERVICE_ID = 6661
ETYPE_SAMPLING_SERVICE_ID = 6662


class SubgraphResponse(Response):
    """The response for sampling and in_subgraph"""

    def __init__(self, global_src, global_dst, global_eids):
        self.global_src = global_src
        self.global_dst = global_dst
        self.global_eids = global_eids

    def __setstate__(self, state):
        self.global_src, self.global_dst, self.global_eids = state

    def __getstate__(self):
        return self.global_src, self.global_dst, self.global_eids

class FindEdgeResponse(Response):
    """The response for sampling and in_subgraph"""

    def __init__(self, global_src, global_dst, order_id):
        self.global_src = global_src
        self.global_dst = global_dst
        self.order_id = order_id

    def __setstate__(self, state):
        self.global_src, self.global_dst, self.order_id = state

    def __getstate__(self):
        return self.global_src, self.global_dst, self.order_id


def _sample_neighbors(
    local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace
):
    """Sample from local partition.

    The input nodes use global IDs. We need to map the global node IDs to local node IDs,
    perform sampling and map the sampled results back to the global ID space.
    The sampled results are stored in three vectors that store source nodes, destination nodes
    and edge IDs.
    """
    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
    local_ids = F.astype(local_ids, local_g.idtype)
    sampled_graph = local_sample_neighbors(
        local_g,
        local_ids,
        fan_out,
        edge_dir,
        prob,
        replace,
        _dist_training=True,
    )
    global_nid_mapping = local_g.ndata[NID]
    src, dst = sampled_graph.edges()
    global_src, global_dst = F.gather_row(
        global_nid_mapping, src
    ), F.gather_row(global_nid_mapping, dst)
    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
    return global_src, global_dst, global_eids


def _sample_etype_neighbors(
    local_g,
    partition_book,
    seed_nodes,
    etype_offset,
    fan_out,
    edge_dir,
    prob,
    replace,
    etype_sorted=False,
):
    """Sample from local partition.

    The input nodes use global IDs. We need to map the global node IDs to local node IDs,
    perform sampling and map the sampled results back to the global ID space.
    The sampled results are stored in three vectors that store source nodes, destination nodes
    and edge IDs.
    """
    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
    local_ids = F.astype(local_ids, local_g.idtype)
    sampled_graph = local_sample_etype_neighbors(
        local_g,
        local_ids,
        etype_offset,
        fan_out,
        edge_dir,
        prob,
        replace,
        etype_sorted=etype_sorted,
        _dist_training=True,
    )
    global_nid_mapping = local_g.ndata[NID]
    src, dst = sampled_graph.edges()
    global_src, global_dst = F.gather_row(
        global_nid_mapping, src
    ), F.gather_row(global_nid_mapping, dst)
    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
    return global_src, global_dst, global_eids


def _find_edges(local_g, partition_book, seed_edges):
    """Given an edge ID array, return the source and destination node ID
    arrays ``s`` and ``d`` in the local partition.
    """
    local_eids = partition_book.eid2localeid(seed_edges, partition_book.partid)
    local_eids = F.astype(local_eids, local_g.idtype)
    local_src, local_dst = local_g.find_edges(local_eids)
    global_nid_mapping = local_g.ndata[NID]
    global_src = global_nid_mapping[local_src]
    global_dst = global_nid_mapping[local_dst]
    return global_src, global_dst


def _in_degrees(local_g, partition_book, n):
    """Get in-degree of the nodes in the local partition."""
    local_nids = partition_book.nid2localnid(n, partition_book.partid)
    local_nids = F.astype(local_nids, local_g.idtype)
    return local_g.in_degrees(local_nids)


def _out_degrees(local_g, partition_book, n):
    """Get out-degree of the nodes in the local partition."""
    local_nids = partition_book.nid2localnid(n, partition_book.partid)
    local_nids = F.astype(local_nids, local_g.idtype)
    return local_g.out_degrees(local_nids)


def _in_subgraph(local_g, partition_book, seed_nodes):
    """Get the in-subgraph of the local partition.

    The input nodes use global IDs. We need to map the global node IDs to local node IDs,
    extract the in-subgraph and map the results back to the global ID space.
    The results are stored in three vectors that store source nodes, destination nodes
    and edge IDs.
    """
    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
    local_ids = F.astype(local_ids, local_g.idtype)
    sampled_graph = local_in_subgraph(local_g, local_ids)
    global_nid_mapping = local_g.ndata[NID]
    src, dst = sampled_graph.edges()
    global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]
    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
    return global_src, global_dst, global_eids


# --- NOTE 1 ---
# (BarclayII)
# If the sampling algorithm needs node and edge data, ideally the
# algorithm should query the underlying feature storage to get just what it
# needs to complete the job.  For instance, with
# sample_etype_neighbors, we only need the probability of the seed nodes'
# neighbors.
#
# However, right now we are reusing the existing subgraph sampling
# interfaces of DGLGraph (i.e. single machine solution), which needs
# the data of *all* the nodes/edges.  Going distributed, we now need
# the node/edge data of the *entire* local graph partition.
#
# If the sampling algorithm only uses edge data, the current design works
# because the local graph partition contains all the in-edges of the
# assigned nodes as well as the data.  This is the case for
# sample_etype_neighbors.
#
# However, if the sampling algorithm requires data of the neighbor nodes
# (e.g. sample_neighbors_biased which performs biased sampling based on the
# type of the neighbor nodes), the current design will fail because the
# neighbor nodes (hence the data) may not belong to the current partition.
# This is a limitation of the current DistDGL design.  We should improve it
# later.


class SamplingRequest(Request):
    """Sampling Request"""

    def __init__(self, nodes, fan_out, edge_dir="in", prob=None, replace=False):
        self.seed_nodes = nodes
        self.edge_dir = edge_dir
        self.prob = prob
        self.replace = replace
        self.fan_out = fan_out

    def __setstate__(self, state):
        (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
        ) = state

    def __getstate__(self):
        return (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
        )

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        kv_store = server_state.kv_store
        if self.prob is not None:
            prob = [kv_store.data_store[self.prob]]
        else:
            prob = None
        global_src, global_dst, global_eids = _sample_neighbors(
            local_g,
            partition_book,
            self.seed_nodes,
            self.fan_out,
            self.edge_dir,
            prob,
            self.replace,
        )
        return SubgraphResponse(global_src, global_dst, global_eids)


class SamplingRequestEtype(Request):
    """Sampling Request"""

    def __init__(
        self,
        nodes,
        fan_out,
        edge_dir="in",
        prob=None,
        replace=False,
        etype_sorted=True,
    ):
        self.seed_nodes = nodes
        self.edge_dir = edge_dir
        self.prob = prob
        self.replace = replace
        self.fan_out = fan_out
        self.etype_sorted = etype_sorted

    def __setstate__(self, state):
        (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
            self.etype_sorted,
        ) = state

    def __getstate__(self):
        return (
            self.seed_nodes,
            self.edge_dir,
            self.prob,
            self.replace,
            self.fan_out,
            self.etype_sorted,
        )

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        kv_store = server_state.kv_store
        etype_offset = partition_book.local_etype_offset
        # See NOTE 1
        if self.prob is not None:
            probs = [
                kv_store.data_store[key] if key != "" else None
                for key in self.prob
            ]
        else:
            probs = None
        global_src, global_dst, global_eids = _sample_etype_neighbors(
            local_g,
            partition_book,
            self.seed_nodes,
            etype_offset,
            self.fan_out,
            self.edge_dir,
            probs,
            self.replace,
            self.etype_sorted,
        )
        return SubgraphResponse(global_src, global_dst, global_eids)


class EdgesRequest(Request):
    """Edges Request"""

    def __init__(self, edge_ids, order_id):
        self.edge_ids = edge_ids
        self.order_id = order_id

    def __setstate__(self, state):
        self.edge_ids, self.order_id = state

    def __getstate__(self):
        return self.edge_ids, self.order_id

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        global_src, global_dst = _find_edges(
            local_g, partition_book, self.edge_ids
        )

        return FindEdgeResponse(global_src, global_dst, self.order_id)

class InDegreeRequest(Request):
    """In-degree Request"""

    def __init__(self, n, order_id):
        self.n = n
        self.order_id = order_id

    def __setstate__(self, state):
        self.n, self.order_id = state

    def __getstate__(self):
        return self.n, self.order_id

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        deg = _in_degrees(local_g, partition_book, self.n)

        return InDegreeResponse(deg, self.order_id)

class InDegreeResponse(Response):
    """The response for in-degree"""

    def __init__(self, deg, order_id):
        self.val = deg
        self.order_id = order_id

    def __setstate__(self, state):
        self.val, self.order_id = state

    def __getstate__(self):
        return self.val, self.order_id

class OutDegreeRequest(Request):
    """Out-degree Request"""

    def __init__(self, n, order_id):
        self.n = n
        self.order_id = order_id

    def __setstate__(self, state):
        self.n, self.order_id = state

    def __getstate__(self):
        return self.n, self.order_id

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        deg = _out_degrees(local_g, partition_book, self.n)

        return OutDegreeResponse(deg, self.order_id)

class OutDegreeResponse(Response):
    """The response for out-degree"""

    def __init__(self, deg, order_id):
        self.val = deg
        self.order_id = order_id

    def __setstate__(self, state):
        self.val, self.order_id = state

    def __getstate__(self):
        return self.val, self.order_id


class InSubgraphRequest(Request):
    """InSubgraph Request"""

    def __init__(self, nodes):
        self.seed_nodes = nodes

    def __setstate__(self, state):
        self.seed_nodes = state

    def __getstate__(self):
        return self.seed_nodes

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        global_src, global_dst, global_eids = _in_subgraph(
            local_g, partition_book, self.seed_nodes
        )
        return SubgraphResponse(global_src, global_dst, global_eids)


def merge_graphs(res_list, num_nodes):
    """Merge request from multiple servers"""
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
    if len(res_list) > 1:
        srcs = []
        dsts = []
        eids = []
        for res in res_list:
            srcs.append(res.global_src)
            dsts.append(res.global_dst)
            eids.append(res.global_eids)
        src_tensor = F.cat(srcs, 0)
        dst_tensor = F.cat(dsts, 0)
        eid_tensor = F.cat(eids, 0)
    else:
        src_tensor = res_list[0].global_src
        dst_tensor = res_list[0].global_dst
        eid_tensor = res_list[0].global_eids
457
    g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
Jinjing Zhou's avatar
Jinjing Zhou committed
458
459
460
    g.edata[EID] = eid_tensor
    return g


LocalSampledGraph = namedtuple(
    "LocalSampledGraph", "global_src global_dst global_eids"
)


def _distributed_access(g, nodes, issue_remote_req, local_access):
    """A routine that fetches the local neighborhoods of nodes from the distributed graph.

    The local neighborhoods of some nodes are stored on the local machine and the other
    nodes have their neighborhoods on remote machines. This code will issue remote
    access requests first before fetching data from the local machine. In the end,
    we combine the data from the local machine and remote machines.
    In this way, we can hide the latency of accessing data on remote machines.

    Parameters
    ----------
    g : DistGraph
        The distributed graph
    nodes : tensor
        The nodes whose neighborhoods are to be fetched.
    issue_remote_req : callable
        The function that issues requests to access remote data.
    local_access : callable
        The function that reads data on the local machine.

    Returns
    -------
    DGLGraph
        The subgraph that contains the neighborhoods of all input nodes.
    """
    req_list = []
    partition_book = g.get_partition_book()
    nodes = toindex(nodes).tousertensor()
    partition_id = partition_book.nid2partid(nodes)
    local_nids = None
    for pid in range(partition_book.num_partitions()):
        node_id = F.boolean_mask(nodes, partition_id == pid)
        # We optimize the sampling on a local partition if the server and the client
        # run on the same machine. With a good partitioning, most of the seed nodes
        # should reside in the local partition. If the server and the client
        # are not co-located, the client doesn't have a local partition.
        if pid == partition_book.partid and g.local_partition is not None:
            assert local_nids is None
            local_nids = node_id
        elif len(node_id) != 0:
            req = issue_remote_req(node_id)
            req_list.append((pid, req))

    # send requests to the remote machine.
    msgseq2pos = None
    if len(req_list) > 0:
        msgseq2pos = send_requests_to_machine(req_list)

    # sample neighbors for the nodes in the local partition.
    res_list = []
    if local_nids is not None:
        src, dst, eids = local_access(
            g.local_partition, partition_book, local_nids
        )
        res_list.append(LocalSampledGraph(src, dst, eids))

    # receive responses from remote machines.
    if msgseq2pos is not None:
        results = recv_responses(msgseq2pos)
        res_list.extend(results)

    sampled_graph = merge_graphs(res_list, g.number_of_nodes())
    return sampled_graph
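
# A sketch of how the samplers below drive ``_distributed_access`` (purely
# illustrative; ``seeds`` stands for a hypothetical tensor of global node IDs,
# and the two callbacks simply forward to the request class and the local
# helper defined above):
#
#     def issue_remote_req(node_ids):
#         return SamplingRequest(node_ids, fan_out=5)
#
#     def local_access(local_g, partition_book, local_nids):
#         return _sample_neighbors(
#             local_g, partition_book, local_nids, 5, "in", None, False
#         )
#
#     frontier = _distributed_access(g, seeds, issue_remote_req, local_access)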


def _frontier_to_heterogeneous_graph(g, frontier, gpb):
    # We need to handle empty frontiers correctly.
    if frontier.number_of_edges() == 0:
        data_dict = {
            etype: (np.zeros(0), np.zeros(0)) for etype in g.canonical_etypes
        }
        return heterograph(
            data_dict,
            {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
            idtype=g.idtype,
        )

    etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
    src, dst = frontier.edges()
    etype_ids, idx = F.sort_1d(etype_ids)
    src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
    eid = F.gather_row(frontier.edata[EID], idx)
    _, src = gpb.map_to_per_ntype(src)
    _, dst = gpb.map_to_per_ntype(dst)

    data_dict = dict()
    edge_ids = {}
    for etid, etype in enumerate(g.canonical_etypes):
        type_idx = etype_ids == etid
        if F.sum(type_idx, 0) > 0:
            data_dict[etype] = (
                F.boolean_mask(src, type_idx),
                F.boolean_mask(dst, type_idx),
            )
            edge_ids[etype] = F.boolean_mask(eid, type_idx)
    hg = heterograph(
        data_dict,
        {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
        idtype=g.idtype,
    )

    for etype in edge_ids:
        hg.edges[etype].data[EID] = edge_ids[etype]
    return hg


def sample_etype_neighbors(
    g,
    nodes,
    fanout,
    edge_dir="in",
    prob=None,
    replace=False,
    etype_sorted=True,
):
    """Sample from the neighbors of the given nodes from a distributed graph.

    For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
    will be randomly chosen.  The returned graph will contain all the nodes in the
    original graph, but only the sampled edges.

    Node/edge features are not preserved. The original IDs of
    the sampled edges are stored as the `dgl.EID` feature in the returned graph.

    This function assumes the input is a homogeneous ``DGLGraph`` with the edges
    ordered by their edge types. The sampled subgraph is also
    stored in the homogeneous graph format. That is, all nodes and edges are assigned
    with unique IDs (in contrast, we typically use a type name and a node/edge ID to
    identify a node or an edge in ``DGLGraph``). We refer to this type of ID
    as a *homogeneous ID*.
    Users can use :func:`dgl.distributed.GraphPartitionBook.map_to_per_ntype`
    and :func:`dgl.distributed.GraphPartitionBook.map_to_per_etype`
    to identify their node/edge types and node/edge IDs of that type.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    nodes : tensor or dict
        Node IDs to sample neighbors from. If it's a dict, it should contain only
        one key-value pair to make this API consistent with dgl.sampling.sample_neighbors.
    fanout : int or dict[etype, int]
        The number of edges to be sampled for each node per edge type.  If an integer
        is given, DGL assumes that the same fanout is applied to every edge type.

        If -1 is given, all of the neighbors will be selected.
    edge_dir : str, optional
        Determines whether to sample inbound or outbound edges.

        Can take either ``in`` for inbound edges or ``out`` for outbound edges.
    prob : str, optional
        Feature name used as the (unnormalized) probabilities associated with each
        neighboring edge of a node.  The feature must have only one element for each
        edge.

        The features must be non-negative floats, and the sum of the features of
        inbound/outbound edges for every node must be positive (though they don't have
        to sum up to one).  Otherwise, the result will be undefined.
    replace : bool, optional
        If True, sample with replacement.

        When sampling with replacement, the sampled subgraph could have parallel edges.

        For sampling without replacement, if fanout > the number of neighbors, all the
        neighbors are sampled. If fanout == -1, all neighbors are collected.
    etype_sorted : bool, optional
        Indicates whether etypes are sorted.

    Returns
    -------
    DGLGraph
        A sampled subgraph containing only the sampled neighboring edges.  It is on CPU.
    """
    if isinstance(fanout, int):
        fanout = F.full_1d(len(g.canonical_etypes), fanout, F.int64, F.cpu())
    else:
        etype_ids = {etype: i for i, etype in enumerate(g.canonical_etypes)}
        fanout_array = [None] * len(g.canonical_etypes)
        for etype, v in fanout.items():
            c_etype = g.to_canonical_etype(etype)
            fanout_array[etype_ids[c_etype]] = v
        assert all(v is not None for v in fanout_array), (
            "Not all etypes have valid fanout. Please make sure passed-in "
            "fanout in dict includes all the etypes in graph. Passed-in "
            f"fanout: {fanout}, graph etypes: {g.canonical_etypes}."
        )
        fanout = F.tensor(fanout_array, dtype=F.int64)

    gpb = g.get_partition_book()
    if isinstance(nodes, dict):
        homo_nids = []
        for ntype in nodes.keys():
            assert (
                ntype in g.ntypes
            ), "The sampled node type {} does not exist in the input graph".format(
                ntype
            )
            if F.is_tensor(nodes[ntype]):
                typed_nodes = nodes[ntype]
            else:
                typed_nodes = toindex(nodes[ntype]).tousertensor()
            homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
        nodes = F.cat(homo_nids, 0)

    def issue_remote_req(node_ids):
        if prob is not None:
            # See NOTE 1
            _prob = [
                # NOTE (BarclayII)
                # Currently DistGraph.edges[] does not accept canonical etype.
                g.edges[etype].data[prob].kvstore_key
                if prob in g.edges[etype].data
                else ""
                for etype in g.canonical_etypes
            ]
        else:
            _prob = None
        return SamplingRequestEtype(
            node_ids,
            fanout,
            edge_dir=edge_dir,
            prob=_prob,
            replace=replace,
            etype_sorted=etype_sorted,
        )

    def local_access(local_g, partition_book, local_nids):
        etype_offset = gpb.local_etype_offset
        # See NOTE 1
        if prob is None:
            _prob = None
        else:
            _prob = [
                g.edges[etype].data[prob].local_partition
                if prob in g.edges[etype].data
                else None
                for etype in g.canonical_etypes
            ]
        return _sample_etype_neighbors(
            local_g,
            partition_book,
            local_nids,
            etype_offset,
            fanout,
            edge_dir,
            _prob,
            replace,
            etype_sorted=etype_sorted,
        )

    frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
    if not gpb.is_homogeneous:
        return _frontier_to_heterogeneous_graph(g, frontier, gpb)
    else:
        return frontier
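
# A minimal usage sketch for ``sample_etype_neighbors`` (illustrative only: it
# assumes ``g`` is an initialized heterogeneous ``DistGraph`` and uses a
# hypothetical node type, canonical edge type and ``torch`` seed tensor):
#
#     seeds = {"user": torch.tensor([0, 1, 2])}
#     frontier = sample_etype_neighbors(
#         g, seeds, fanout={("user", "follows", "user"): 5}
#     )
#     # The sampled edges keep their original IDs in the ``dgl.EID`` feature of
#     # the returned graph.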


def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
    """Sample from the neighbors of the given nodes from a distributed graph.

    For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
    will be randomly chosen.  The returned graph will contain all the nodes in the
    original graph, but only the sampled edges.

    Node/edge features are not preserved. The original IDs of
    the sampled edges are stored as the `dgl.EID` feature in the returned graph.

    For heterogeneous graphs, ``nodes`` is a dictionary whose key is node type
    and the value is type-specific node IDs.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    nodes : tensor or dict
        Node IDs to sample neighbors from. If it's a dict, it should contain only
        one key-value pair to make this API consistent with dgl.sampling.sample_neighbors.
    fanout : int
        The number of edges to be sampled for each node.

        If -1 is given, all of the neighbors will be selected.
    edge_dir : str, optional
        Determines whether to sample inbound or outbound edges.

        Can take either ``in`` for inbound edges or ``out`` for outbound edges.
    prob : str, optional
        Feature name used as the (unnormalized) probabilities associated with each
        neighboring edge of a node.  The feature must have only one element for each
        edge.

        The features must be non-negative floats, and the sum of the features of
        inbound/outbound edges for every node must be positive (though they don't have
        to sum up to one).  Otherwise, the result will be undefined.
    replace : bool, optional
        If True, sample with replacement.

        When sampling with replacement, the sampled subgraph could have parallel edges.

        For sampling without replacement, if fanout > the number of neighbors, all the
        neighbors are sampled. If fanout == -1, all neighbors are collected.

    Returns
    -------
    DGLGraph
        A sampled subgraph containing only the sampled neighboring edges.  It is on CPU.
    """
    gpb = g.get_partition_book()
    if not gpb.is_homogeneous:
        assert isinstance(nodes, dict)
        homo_nids = []
        for ntype in nodes:
            assert (
                ntype in g.ntypes
            ), "The sampled node type does not exist in the input graph"
            if F.is_tensor(nodes[ntype]):
                typed_nodes = nodes[ntype]
            else:
                typed_nodes = toindex(nodes[ntype]).tousertensor()
            homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
        nodes = F.cat(homo_nids, 0)
    elif isinstance(nodes, dict):
        assert len(nodes) == 1
        nodes = list(nodes.values())[0]

    def issue_remote_req(node_ids):
        if prob is not None:
            # See NOTE 1
            _prob = g.edata[prob].kvstore_key
        else:
            _prob = None
        return SamplingRequest(
            node_ids, fanout, edge_dir=edge_dir, prob=_prob, replace=replace
        )

    def local_access(local_g, partition_book, local_nids):
        # See NOTE 1
        _prob = [g.edata[prob].local_partition] if prob is not None else None
        return _sample_neighbors(
            local_g,
            partition_book,
            local_nids,
            fanout,
            edge_dir,
            _prob,
            replace,
        )

    frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
    if not gpb.is_homogeneous:
        return _frontier_to_heterogeneous_graph(g, frontier, gpb)
    else:
        return frontier
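
# A minimal usage sketch for ``sample_neighbors`` (illustrative only: it
# assumes ``g`` is an initialized homogeneous ``DistGraph`` and ``seeds`` is a
# hypothetical tensor of node IDs):
#
#     frontier = sample_neighbors(g, seeds, fanout=10, edge_dir="in")
#     # The frontier keeps all nodes of ``g`` but only the sampled edges; their
#     # original edge IDs are stored in the ``dgl.EID`` edge feature.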


def _distributed_edge_access(g, edges, issue_remote_req, local_access):
    """A routine that fetches local edges from distributed graph.

    The source and destination nodes of local edges are stored in the local
    machine and others are stored on remote machines. This code will issue
    remote access requests first before fetching data from the local machine.
    In the end, we combine the data from the local machine and remote machines.

    Parameters
    ----------
    g : DistGraph
        The distributed graph
    edges : tensor
        The edges to find their source and destination nodes.
    issue_remote_req : callable
        The function that issues requests to access remote data.
    local_access : callable
        The function that reads data on the local machine.

    Returns
    -------
    tensor
        The source node ID array.
    tensor
        The destination node ID array.
    """
    req_list = []
    partition_book = g.get_partition_book()
    edges = toindex(edges).tousertensor()
    partition_id = partition_book.eid2partid(edges)
    local_eids = None
    reorder_idx = []
    for pid in range(partition_book.num_partitions()):
        mask = partition_id == pid
        edge_id = F.boolean_mask(edges, mask)
        reorder_idx.append(F.nonzero_1d(mask))
        if pid == partition_book.partid and g.local_partition is not None:
            assert local_eids is None
            local_eids = edge_id
        elif len(edge_id) != 0:
            req = issue_remote_req(edge_id, pid)
            req_list.append((pid, req))

    # send requests to the remote machine.
    msgseq2pos = None
    if len(req_list) > 0:
        msgseq2pos = send_requests_to_machine(req_list)

    # handle edges in local partition.
    src_ids = F.zeros_like(edges)
    dst_ids = F.zeros_like(edges)
    if local_eids is not None:
        src, dst = local_access(g.local_partition, partition_book, local_eids)
        src_ids = F.scatter_row(
            src_ids, reorder_idx[partition_book.partid], src
        )
        dst_ids = F.scatter_row(
            dst_ids, reorder_idx[partition_book.partid], dst
        )

    # receive responses from remote machines.
    if msgseq2pos is not None:
        results = recv_responses(msgseq2pos)
        for result in results:
            src = result.global_src
            dst = result.global_dst
            src_ids = F.scatter_row(src_ids, reorder_idx[result.order_id], src)
            dst_ids = F.scatter_row(dst_ids, reorder_idx[result.order_id], dst)
    return src_ids, dst_ids


def find_edges(g, edge_ids):
    """Given an edge ID array, return the source and destination
    node ID arrays ``s`` and ``d`` from a distributed graph.
    ``s[i]`` and ``d[i]`` are the source and destination node IDs
    of edge ``edge_ids[i]``.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    edge_ids : tensor
        The edge ID array.

    Returns
    -------
    tensor
        The source node ID array.
    tensor
        The destination node ID array.
    """

    def issue_remote_req(edge_ids, order_id):
        return EdgesRequest(edge_ids, order_id)

    def local_access(local_g, partition_book, edge_ids):
        return _find_edges(local_g, partition_book, edge_ids)

    return _distributed_edge_access(g, edge_ids, issue_remote_req, local_access)
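
# A minimal usage sketch for ``find_edges`` (illustrative only; ``eids`` stands
# for a hypothetical tensor of global edge IDs of the distributed graph ``g``):
#
#     src, dst = find_edges(g, eids)
#     # src[i] and dst[i] are the endpoints of the edge whose ID is eids[i].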


def in_subgraph(g, nodes):
    """Return the subgraph induced on the inbound edges of the given nodes.

    The subgraph keeps the same type schema and all the nodes are preserved regardless
    of whether they have an edge or not.

    Node/edge features are not preserved. The original IDs of
    the extracted edges are stored as the `dgl.EID` feature in the returned graph.

    For now, we only support the input graph with one node type and one edge type.

    Parameters
    ----------
    g : DistGraph
        The distributed graph structure.
    nodes : tensor or dict
        Node IDs whose inbound edges will be included in the subgraph.

    Returns
    -------
    DGLGraph
        The subgraph.

        One can retrieve the mapping from subgraph edge ID to parent
        edge ID via ``dgl.EID`` edge features of the subgraph.
    """
    if isinstance(nodes, dict):
        assert (
            len(nodes) == 1
        ), "The distributed in_subgraph only supports one node type for now."
        nodes = list(nodes.values())[0]

    def issue_remote_req(node_ids):
        return InSubgraphRequest(node_ids)

    def local_access(local_g, partition_book, local_nids):
        return _in_subgraph(local_g, partition_book, local_nids)

    return _distributed_access(g, nodes, issue_remote_req, local_access)
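
# A minimal usage sketch for ``in_subgraph`` (illustrative only; ``seeds``
# stands for a hypothetical tensor of node IDs of the distributed graph ``g``):
#
#     sg = in_subgraph(g, seeds)
#     # ``sg`` contains every inbound edge of the seed nodes; the mapping back
#     # to the parent edge IDs is stored in ``sg.edata[dgl.EID]``.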


def _distributed_get_node_property(g, n, issue_remote_req, local_access):
    """Fetch a per-node property (e.g. degree) of the nodes ``n`` from the distributed graph."""
    req_list = []
    partition_book = g.get_partition_book()
    n = toindex(n).tousertensor()
    partition_id = partition_book.nid2partid(n)
    local_nids = None
    reorder_idx = []
    for pid in range(partition_book.num_partitions()):
        mask = partition_id == pid
        nid = F.boolean_mask(n, mask)
        reorder_idx.append(F.nonzero_1d(mask))
        if pid == partition_book.partid and g.local_partition is not None:
            assert local_nids is None
            local_nids = nid
        elif len(nid) != 0:
            req = issue_remote_req(nid, pid)
            req_list.append((pid, req))

    # send requests to the remote machine.
    msgseq2pos = None
    if len(req_list) > 0:
        msgseq2pos = send_requests_to_machine(req_list)

    # handle nodes in the local partition.
    vals = None
    if local_nids is not None:
        local_vals = local_access(g.local_partition, partition_book, local_nids)
        shape = list(F.shape(local_vals))
        shape[0] = len(n)
        vals = F.zeros(shape, F.dtype(local_vals), F.cpu())
        vals = F.scatter_row(
            vals, reorder_idx[partition_book.partid], local_vals
        )

    # receive responses from remote machines.
    if msgseq2pos is not None:
        results = recv_responses(msgseq2pos)
        if len(results) > 0 and vals is None:
            shape = list(F.shape(results[0].val))
            shape[0] = len(n)
            vals = F.zeros(shape, F.dtype(results[0].val), F.cpu())
        for result in results:
            val = result.val
            vals = F.scatter_row(vals, reorder_idx[result.order_id], val)
    return vals


def in_degrees(g, v):
    """Get in-degrees of the given nodes from the distributed graph."""

    def issue_remote_req(v, order_id):
        return InDegreeRequest(v, order_id)

    def local_access(local_g, partition_book, v):
        return _in_degrees(local_g, partition_book, v)

    return _distributed_get_node_property(g, v, issue_remote_req, local_access)


def out_degrees(g, u):
    """Get out-degrees of the given nodes from the distributed graph."""

    def issue_remote_req(u, order_id):
        return OutDegreeRequest(u, order_id)

    def local_access(local_g, partition_book, u):
        return _out_degrees(local_g, partition_book, u)

    return _distributed_get_node_property(g, u, issue_remote_req, local_access)
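
# A minimal usage sketch for the degree helpers (illustrative only; ``nids``
# stands for a hypothetical tensor of node IDs of the distributed graph ``g``):
#
#     in_deg = in_degrees(g, nids)
#     out_deg = out_degrees(g, nids)
#     # Both return one degree value per entry of ``nids``, combining results
#     # from the local partition and from remote machines.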


register_service(SAMPLING_SERVICE_ID, SamplingRequest, SubgraphResponse)
register_service(EDGES_SERVICE_ID, EdgesRequest, FindEdgeResponse)
register_service(INSUBGRAPH_SERVICE_ID, InSubgraphRequest, SubgraphResponse)
register_service(OUTDEGREE_SERVICE_ID, OutDegreeRequest, OutDegreeResponse)
register_service(INDEGREE_SERVICE_ID, InDegreeRequest, InDegreeResponse)
register_service(
    ETYPE_SAMPLING_SERVICE_ID, SamplingRequestEtype, SubgraphResponse
)