"tutorials/vscode:/vscode.git/clone" did not exist on "acefb9a9f0af08a72d0e25d632e3a0be493e9178"
dist_graph.py 27.2 KB
Newer Older
1
2
3
"""Define distributed graph."""

from collections.abc import MutableMapping
import os
import numpy as np

from ..heterograph import DGLHeteroGraph
from .. import heterograph_index
from .. import backend as F
from ..base import NID, EID
from .kvstore import KVServer, get_kvstore
from .standalone_kvstore import KVClient as SA_KVClient
from .._ffi.ndarray import empty_shared_mem
from ..frame import infer_scheme
from .partition import load_partition, load_partition_book
from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book
from .graph_partition_book import NODE_PART_POLICY, EDGE_PART_POLICY
from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
from . import rpc
from . import role
from .server_state import ServerState
from .rpc_server import start_server
from .graph_services import find_edges as dist_find_edges
from .dist_tensor import DistTensor, _get_data_name

INIT_GRAPH = 800001

class InitGraphRequest(rpc.Request):
    """ Init graph on the backup servers.

    When the backup servers start, they don't load the graph structure.
    This request tells a backup server that it can map the graph structure
    from shared memory.
    """
    def __init__(self, graph_name):
        self._graph_name = graph_name

    def __getstate__(self):
        return self._graph_name

    def __setstate__(self, state):
        self._graph_name = state

    def process_request(self, server_state):
        if server_state.graph is None:
            server_state.graph = _get_graph_from_shared_mem(self._graph_name)
        return InitGraphResponse(self._graph_name)

class InitGraphResponse(rpc.Response):
    """ Ack the init graph request
    """
    def __init__(self, graph_name):
        self._graph_name = graph_name

    def __getstate__(self):
        return self._graph_name

    def __setstate__(self, state):
        self._graph_name = state

def _copy_graph_to_shared_mem(g, graph_name):
    new_g = g.shared_memory(graph_name, formats='csc')
    # We should share the node/edge data with the client explicitly instead of putting them
    # in the KVStore because some of the node/edge data may be duplicated.
    local_node_path = _get_ndata_path(graph_name, 'inner_node')
    new_g.ndata['inner_node'] = _to_shared_mem(g.ndata['inner_node'], local_node_path)
    local_edge_path = _get_edata_path(graph_name, 'inner_edge')
    new_g.edata['inner_edge'] = _to_shared_mem(g.edata['inner_edge'], local_edge_path)
    new_g.ndata[NID] = _to_shared_mem(g.ndata[NID], _get_ndata_path(graph_name, NID))
    new_g.edata[EID] = _to_shared_mem(g.edata[EID], _get_edata_path(graph_name, EID))
    return new_g

FIELD_DICT = {'inner_node': F.int64,
              'inner_edge': F.int64,
              NID: F.int64,
              EID: F.int64}

def _is_ndata_name(name):
    ''' Is this node data in the kvstore '''
    return name[:5] == NODE_PART_POLICY + ':'

def _is_edata_name(name):
    ''' Is this edge data in the kvstore '''
    return name[:5] == EDGE_PART_POLICY + ':'

def _get_shared_mem_ndata(g, graph_name, name):
    ''' Get shared-memory node data from DistGraph server.

    This is called by the DistGraph client to access the node data in the DistGraph server
    with shared memory.
    '''
    shape = (g.number_of_nodes(),)
    dtype = FIELD_DICT[name]
    dtype = DTYPE_DICT[dtype]
    data = empty_shared_mem(_get_ndata_path(graph_name, name), False, shape, dtype)
    dlpack = data.to_dlpack()
    return F.zerocopy_from_dlpack(dlpack)

def _get_shared_mem_edata(g, graph_name, name):
    ''' Get shared-memory edge data from DistGraph server.

    This is called by the DistGraph client to access the edge data in the DistGraph server
    with shared memory.
    '''
    shape = (g.number_of_edges(),)
    dtype = FIELD_DICT[name]
    dtype = DTYPE_DICT[dtype]
    data = empty_shared_mem(_get_edata_path(graph_name, name), False, shape, dtype)
    dlpack = data.to_dlpack()
    return F.zerocopy_from_dlpack(dlpack)

def _get_graph_from_shared_mem(graph_name):
    ''' Get the graph from the DistGraph server.

    The DistGraph server puts the graph structure of the local partition in the shared memory.
    The client can access the graph structure and some metadata on nodes and edges directly
    through shared memory to reduce the overhead of data access.
    '''
    g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(graph_name)
    if g is None:
        return None
    g = DGLHeteroGraph(g, ntypes, etypes)
    g.ndata['inner_node'] = _get_shared_mem_ndata(g, graph_name, 'inner_node')
    g.edata['inner_edge'] = _get_shared_mem_edata(g, graph_name, 'inner_edge')
    g.ndata[NID] = _get_shared_mem_ndata(g, graph_name, NID)
    g.edata[EID] = _get_shared_mem_edata(g, graph_name, EID)
    return g

class NodeDataView(MutableMapping):
    """The data view class when dist_graph.ndata[...].data is called.
    """
    __slots__ = ['_graph', '_data']

    def __init__(self, g):
        self._graph = g
        # When this is created, the server may already load node data. We need to
        # initialize the node data in advance.
        names = g._get_all_ndata_names()
        policy = PartitionPolicy(NODE_PART_POLICY, g.get_partition_book())
        self._data = {}
        for name in names:
            name1 = _get_data_name(name, policy.policy_str)
            dtype, shape, _ = g._client.get_data_meta(name1)
            # We create a wrapper on the existing tensor in the kvstore.
            self._data[name] = DistTensor(g, shape, dtype, name, part_policy=policy)

    def _get_names(self):
        return list(self._data.keys())

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, val):
        self._data[key] = val

    def __delitem__(self, key):
        del self._data[key]

    def __len__(self):
        # The number of node data may change. Let's count it every time we need them.
        # It's not called frequently. It should be fine.
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def __repr__(self):
        reprs = {}
        for name in self._data:
            dtype = F.dtype(self._data[name])
            shape = F.shape(self._data[name])
            reprs[name] = 'DistTensor(shape={}, dtype={})'.format(str(shape), str(dtype))
        return repr(reprs)

class EdgeDataView(MutableMapping):
    """The data view class when G.edges[...].data is called.
    """
    __slots__ = ['_graph', '_data']

    def __init__(self, g):
        self._graph = g
        # When this is created, the server may already load edge data. We need to
        # initialize the edge data in advance.
        names = g._get_all_edata_names()
        policy = PartitionPolicy(EDGE_PART_POLICY, g.get_partition_book())
        self._data = {}
        for name in names:
            name1 = _get_data_name(name, policy.policy_str)
            dtype, shape, _ = g._client.get_data_meta(name1)
            # We create a wrapper on the existing tensor in the kvstore.
            self._data[name] = DistTensor(g, shape, dtype, name, part_policy=policy)

    def _get_names(self):
        return list(self._data.keys())

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, val):
        self._data[key] = val

    def __delitem__(self, key):
        del self._data[key]

    def __len__(self):
        # The number of edge data may change. Let's count it every time we need them.
        # It's not called frequently. It should be fine.
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def __repr__(self):
        reprs = {}
        for name in self._data:
            dtype = F.dtype(self._data[name])
            shape = F.shape(self._data[name])
            reprs[name] = 'DistTensor(shape={}, dtype={})'.format(str(shape), str(dtype))
        return repr(reprs)


class DistGraphServer(KVServer):
    ''' The DistGraph server.

    This DistGraph server loads the graph data and sets up a service so that clients can read data
    of a graph partition (graph structure, node data and edge data) from remote machines.
    A server is responsible for one graph partition.

    Currently, each machine runs only one main server with a set of backup servers to handle
    clients' requests. The main server and the backup servers all handle the requests for the same
    graph partition. They all share the partition data (graph structure and node/edge data) with
    shared memory.

    By default, the partition data is shared with the DistGraph clients that run on
    the same machine. However, a user can disable the shared-memory option, which is
    useful when the server and the clients run on different machines.

    Parameters
    ----------
    server_id : int
        The server ID (starts from 0).
    ip_config : str
        Path of IP configuration file.
    num_clients : int
        Total number of client nodes.
    part_config : str
        The path of the config file generated by the partition tool.
    disable_shared_mem : bool
        Disable shared memory.
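
    Examples
    --------
    A minimal, hypothetical launch sketch (the file names and counts below are
    placeholders, not part of this module):

    >>> server = DistGraphServer(server_id=0, ip_config='ip_config.txt',
    ...                          num_clients=1, part_config='data/graph.json')
    >>> server.start()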
    '''
    def __init__(self, server_id, ip_config, num_clients, part_config, disable_shared_mem=False):
        super(DistGraphServer, self).__init__(server_id=server_id, ip_config=ip_config,
                                              num_clients=num_clients)
        self.ip_config = ip_config
        # Load graph partition data.
        if self.is_backup_server():
            # The backup server doesn't load the graph partition. It'll be initialized afterwards.
            self.gpb, graph_name = load_partition_book(part_config, self.part_id)
            self.client_g = None
        else:
            self.client_g, node_feats, edge_feats, self.gpb, \
                    graph_name = load_partition(part_config, self.part_id)
            print('load ' + graph_name)
            if not disable_shared_mem:
                self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name)

        if not disable_shared_mem:
            self.gpb.shared_memory(graph_name)
        assert self.gpb.partid == self.part_id
        self.add_part_policy(PartitionPolicy(NODE_PART_POLICY, self.gpb))
        self.add_part_policy(PartitionPolicy(EDGE_PART_POLICY, self.gpb))

        if not self.is_backup_server():
            for name in node_feats:
                self.init_data(name=_get_data_name(name, NODE_PART_POLICY),
                               policy_str=NODE_PART_POLICY,
                               data_tensor=node_feats[name])
            for name in edge_feats:
                self.init_data(name=_get_data_name(name, EDGE_PART_POLICY),
                               policy_str=EDGE_PART_POLICY,
                               data_tensor=edge_feats[name])

    def start(self):
        """ Start graph store server.
        """
        # start server
        server_state = ServerState(kv_store=self, local_g=self.client_g, partition_book=self.gpb)
        print('start graph service on server {} for part {}'.format(self.server_id, self.part_id))
        start_server(server_id=self.server_id, ip_config=self.ip_config,
                     num_clients=self.num_clients, server_state=server_state)

class DistGraph:
    ''' The DistGraph client.

    This provides the graph interface to access the partitioned graph data for distributed GNN
    training. All data of partitions are loaded by the DistGraph server.

    DistGraph can run in two modes: the standalone mode and the distributed mode.

    * When a user runs the training script normally, DistGraph will be in the standalone mode.
    In this mode, the input graph has to be constructed with only one partition. This mode is
    used for testing and debugging purposes.
    * When a user runs the training script with the distributed launch script, DistGraph will
    be set into the distributed mode. This is used for actual distributed training.

    When running in the distributed mode, `DistGraph` uses shared-memory to access
    the partition data in the local machine.
    This gives the best performance for distributed training when we run `DistGraphServer`
    and `DistGraph` on the same machine. However, a user may want to run them on separate
    machines. In this case, a user may want to disable shared memory by passing
    `disable_shared_mem=True` when creating `DistGraphServer`. When shared memory is disabled,
    a user has to pass a partition book.

    Parameters
    ----------
    ip_config : str
        Path of IP configuration file.
    graph_name : str
        The name of the graph. This name has to be the same as the one used in DistGraphServer.
    gpb : PartitionBook
        The partition book object
    part_config : str
        The partition config file. It's used in the standalone mode.
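
    Examples
    --------
    A minimal, hypothetical sketch of the standalone mode (the graph name and the
    partition config path are placeholders; `DGL_DIST_MODE` defaults to 'standalone'):

    >>> g = DistGraph('ip_config.txt', 'test', part_config='data/test.json')
    >>> print(g.number_of_nodes())

    In the distributed mode (`DGL_DIST_MODE='distributed'`), `DistGraphServer` must be
    started first with the same `ip_config.txt` and graph name.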
    '''
    def __init__(self, ip_config, graph_name, gpb=None, part_config=None):
        self.ip_config = ip_config
        self.graph_name = graph_name
        self._gpb_input = gpb
        if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
            assert part_config is not None, \
                    'When running in the standalone mode, the partition config file is required'
            self._client = SA_KVClient()
            # Load graph partition data.
            g, node_feats, edge_feats, self._gpb, _ = load_partition(part_config, 0)
            assert self._gpb.num_partitions() == 1, \
                    'The standalone mode can only work with the graph data with one partition'
            if self._gpb is None:
                self._gpb = gpb
            self._g = g
            for name in node_feats:
                self._client.add_data(_get_data_name(name, NODE_PART_POLICY), node_feats[name])
            for name in edge_feats:
                self._client.add_data(_get_data_name(name, EDGE_PART_POLICY), edge_feats[name])
            rpc.set_num_client(1)
        else:
            self._init()
            # Tell the backup servers to load the graph structure from shared memory.
            for server_id in range(self._client.num_servers):
                rpc.send_request(server_id, InitGraphRequest(graph_name))
            for server_id in range(self._client.num_servers):
                rpc.recv_response()
            self._client.barrier()

        self._ndata = NodeDataView(self)
        self._edata = EdgeDataView(self)

        self._num_nodes = 0
        self._num_edges = 0
        for part_md in self._gpb.metadata():
            self._num_nodes += int(part_md['num_nodes'])
            self._num_edges += int(part_md['num_edges'])

    def _init(self):
        self._client = get_kvstore()
        self._g = _get_graph_from_shared_mem(self.graph_name)
        self._gpb = get_shared_mem_partition_book(self.graph_name, self._g)
        if self._gpb is None:
            self._gpb = self._gpb_input
        self._client.map_shared_data(self._gpb)

    def __getstate__(self):
        return self.ip_config, self.graph_name, self._gpb

    def __setstate__(self, state):
        self.ip_config, self.graph_name, self._gpb_input = state
        self._init()

        self._ndata = NodeDataView(self)
        self._edata = EdgeDataView(self)
        self._num_nodes = 0
        self._num_edges = 0
        for part_md in self._gpb.metadata():
            self._num_nodes += int(part_md['num_nodes'])
            self._num_edges += int(part_md['num_edges'])

    @property
    def local_partition(self):
        ''' Return the local partition on the client

        DistGraph provides a global view of the distributed graph. Internally,
        it may contain a partition of the graph if it is co-located with
        the server. If there is no co-location, this returns None.

        Returns
        -------
        DGLHeterograph
            The local partition
        '''
        return self._g

    @property
    def ndata(self):
        """Return the data view of all the nodes.

        Returns
        -------
        NodeDataView
            The data view in the distributed graph storage.
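
        Examples
        --------
        A hypothetical access pattern ('feat' is a placeholder feature name and
        `node_ids` a 1D tensor of global node IDs):

        >>> feat = g.ndata['feat']    # a DistTensor over all nodes
        >>> rows = feat[node_ids]     # fetch the rows for the given nodes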
        """
        return self._ndata

    @property
    def edata(self):
        """Return the data view of all the edges.

        Returns
        -------
        EdgeDataView
            The data view in the distributed graph storage.
        """
        return self._edata

    @property
    def ntypes(self):
        """Return the list of node types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> g = DistGraph("ip_config.txt", "test")
        >>> g.ntypes
        ['_U']
        """
        # Currently, we only support a graph with one node type.
        return ['_U']

    @property
    def etypes(self):
        """Return the list of edge types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> g = DistGraph("ip_config.txt", "test")
        >>> g.etypes
        ['_E']
        """
        # Currently, we only support a graph with one edge type.
        return ['_E']

    def number_of_nodes(self):
        """Return the number of nodes"""
        return self._num_nodes

    def number_of_edges(self):
        """Return the number of edges"""
        return self._num_edges

    def node_attr_schemes(self):
        """Return the node feature and embedding schemes."""
        schemes = {}
        for key in self.ndata:
            schemes[key] = infer_scheme(self.ndata[key])
        return schemes

    def edge_attr_schemes(self):
        """Return the edge feature and embedding schemes."""
        schemes = {}
        for key in self.edata:
            schemes[key] = infer_scheme(self.edata[key])
        return schemes

    def rank(self):
        ''' The rank of the distributed graph store.

        Returns
        -------
        int
            The rank of the current graph store.
        '''
        return role.get_global_rank()

    def find_edges(self, edges):
        """ Given an edge ID array, return the source
        and destination node ID array ``s`` and ``d``.  ``s[i]`` and ``d[i]``
        are source and destination node ID for edge ``eid[i]``.

        Parameters
        ----------
        edges : tensor
            The edge ID array.

        Returns
        -------
        tensor
            The source node ID array.
        tensor
            The destination node ID array.
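
        Examples
        --------
        A hypothetical call, where `eids` is a 1D tensor of edge IDs:

        >>> src, dst = g.find_edges(eids)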
        """
        return dist_find_edges(self, edges)

    def get_partition_book(self):
        """Get the partition information.

        Returns
        -------
        GraphPartitionBook
            Object that stores all kinds of partition information.
        """
        return self._gpb

    def barrier(self):
        '''Barrier for all client nodes.

        This API blocks until all the clients have invoked it.
        '''
        self._client.barrier()

    def _get_all_ndata_names(self):
        ''' Get the names of all node data.
        '''
        names = self._client.data_name_list()
        ndata_names = []
        for name in names:
            if _is_ndata_name(name):
                # Remove the prefix "node:"
                ndata_names.append(name[5:])
        return ndata_names

    def _get_all_edata_names(self):
        ''' Get the names of all edge data.
        '''
        names = self._client.data_name_list()
        edata_names = []
        for name in names:
            if _is_edata_name(name):
                # Remove the prefix "edge:"
                edata_names.append(name[5:])
        return edata_names

def _get_overlap(mask_arr, ids):
    """ Select the Ids given a boolean mask array.

    The boolean mask array indicates all of the Ids to be selected. We want to
    find the overlap between the Ids selected by the boolean mask array and
    the Id array.

    Parameters
    ----------
    mask_arr : 1D tensor
        A boolean mask array.
    ids : 1D tensor
        A vector with Ids.

    Returns
    -------
    1D tensor
        The selected Ids.
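
    For example (hypothetical values), ``mask_arr = [0, 1, 1, 0, 1]`` and
    ``ids = [0, 2, 4]`` select ``[2, 4]``.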
    """
    if isinstance(mask_arr, DistTensor):
        masks = mask_arr[ids]
        return F.boolean_mask(ids, masks)
    else:
        masks = F.gather_row(F.tensor(mask_arr), ids)
        return F.boolean_mask(ids, masks)

def _split_local(partition_book, rank, elements, local_eles):
    ''' Split the input element list with respect to data locality.
    '''
    num_clients = role.get_num_trainers()
    num_client_per_part = num_clients // partition_book.num_partitions()
    if rank is None:
        rank = role.get_trainer_rank()
    assert rank < num_clients, \
            'The input rank ({}) is incorrect. #Trainers: {}'.format(rank, num_clients)
    # all ranks of the clients in the same machine are in a contiguous range.
    client_id_in_part = rank % num_client_per_part
    local_eles = _get_overlap(elements, local_eles)

    # get a subset for the local client.
    size = len(local_eles) // num_client_per_part
    # if this isn't the last client in the partition.
    if client_id_in_part + 1 < num_client_per_part:
        return local_eles[(size * client_id_in_part):(size * (client_id_in_part + 1))]
    else:
        return local_eles[(size * client_id_in_part):]

def _split_even(partition_book, rank, elements):
    ''' Split the input element list evenly.
    '''
    num_clients = role.get_num_trainers()
    num_client_per_part = num_clients // partition_book.num_partitions()
    # all ranks of the clients in the same machine are in a contiguous range.
    if rank is None:
        rank = role.get_trainer_rank()
    assert rank < num_clients, \
            'The input rank ({}) is incorrect. #Trainers: {}'.format(rank, num_clients)
    # This conversion of rank is to make the new rank aligned with partitioning.
    client_id_in_part = rank % num_client_per_part
    rank = client_id_in_part + num_client_per_part * partition_book.partid

    if isinstance(elements, DistTensor):
        # Here we need to fetch all elements from the kvstore server.
        # I hope it's OK.
        eles = F.nonzero_1d(elements[0:len(elements)])
    else:
        eles = F.nonzero_1d(F.tensor(elements))

    # here we divide the element list as evenly as possible. If we use range partitioning,
    # the split results also respect the data locality. Range partitioning is the default
    # strategy.
    # TODO(zhegnda) we need another way to divide the list for other partitioning strategy.

    # compute the offset of each split and ensure that the difference of each partition size
    # is 1.
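    # For example (hypothetical numbers): with len(eles) = 10 and num_clients = 4,
    # part_size = 2, sizes becomes [3, 3, 2, 2] and offsets becomes [3, 6, 8, 10].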
    part_size = len(eles) // num_clients
    sizes = [part_size] * num_clients
    remain = len(eles) - part_size * num_clients
    if remain > 0:
        for i in range(num_clients):
            sizes[i] += 1
            remain -= 1
            if remain == 0:
                break
    offsets = np.cumsum(sizes)
    assert offsets[-1] == len(eles)

    if rank == 0:
        return eles[0:offsets[0]]
    else:
        return eles[offsets[rank-1]:offsets[rank]]


def node_split(nodes, partition_book=None, rank=None, force_even=True):
    ''' Split nodes and return a subset for the local rank.

    This function splits the input nodes based on the partition book and
    returns a subset of nodes for the local rank. This method is used for
    dividing workloads for distributed training.

    The input nodes can be stored as a vector of masks. The length of the vector is
    the same as the number of nodes in a graph; 1 indicates that the vertex in
    the corresponding location exists.

    There are two strategies to split the nodes. If `force_even` is set to True (the
    default), the nodes are split evenly so that each process gets almost the same
    number of nodes. Otherwise, the nodes are split to maximize data locality: all
    nodes that belong to a process are returned. Even splitting can still preserve
    data locality when a graph is partitioned with range partitioning; in this case,
    the majority of the nodes returned for a process are the ones that belong to
    the process. If range partitioning is not used, data locality isn't guaranteed.

    Parameters
    ----------
    nodes : 1D tensor or DistTensor
        A boolean mask vector that indicates input nodes.
    partition_book : GraphPartitionBook
        The graph partition book
    rank : int
        The rank of a process. If not given, the rank of the current process is used.
    force_even : bool
        Force the nodes to be split evenly.

    Returns
    -------
    1D-tensor
        The vector of node Ids that belong to the rank.
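
    Examples
    --------
    A hypothetical usage with a boolean node mask ('train_mask' is a placeholder
    feature name):

    >>> pb = g.get_partition_book()
    >>> train_nids = node_split(g.ndata['train_mask'], pb, force_even=True)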
    '''
    num_nodes = 0
    if not isinstance(nodes, DistTensor):
        assert partition_book is not None, 'Regular tensor requires a partition book.'
    elif partition_book is None:
        partition_book = nodes.part_policy.partition_book
    for part in partition_book.metadata():
        num_nodes += part['num_nodes']
    assert len(nodes) == num_nodes, \
            'The length of boolean mask vector should be the number of nodes in the graph.'
    if force_even:
        return _split_even(partition_book, rank, nodes)
    else:
        # Get all nodes that belong to the rank.
        local_nids = partition_book.partid2nids(partition_book.partid)
        return _split_local(partition_book, rank, nodes, local_nids)

def edge_split(edges, partition_book=None, rank=None, force_even=True):
    ''' Split edges and return a subset for the local rank.

    This function splits the input edges based on the partition book and
    returns a subset of edges for the local rank. This method is used for
    dividing workloads for distributed training.

    The input edges can be stored as a vector of masks. The length of the vector is
    the same as the number of edges in a graph; 1 indicates that the edge in
    the corresponding location exists.

    There are two strategies to split the edges. If `force_even` is set to True (the
    default), the edges are split evenly so that each process gets almost the same
    number of edges. Otherwise, the edges are split to maximize data locality: all
    edges that belong to a process are returned. Even splitting can still preserve
    data locality when a graph is partitioned with range partitioning; in this case,
    the majority of the edges returned for a process are the ones that belong to
    the process. If range partitioning is not used, data locality isn't guaranteed.

    Parameters
    ----------
    edges : 1D tensor or DistTensor
        A boolean mask vector that indicates input edges.
    partition_book : GraphPartitionBook
        The graph partition book
    rank : int
        The rank of a process. If not given, the rank of the current process is used.
    force_even : bool
        Force the edges to be split evenly.

    Returns
    -------
    1D-tensor
        The vector of edge Ids that belong to the rank.
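
    Examples
    --------
    A hypothetical usage with a boolean edge mask ('train_mask' is a placeholder
    feature name):

    >>> pb = g.get_partition_book()
    >>> train_eids = edge_split(g.edata['train_mask'], pb, force_even=True)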
    '''
    num_edges = 0
    if not isinstance(edges, DistTensor):
        assert partition_book is not None, 'Regular tensor requires a partition book.'
    elif partition_book is None:
        partition_book = edges.part_policy.partition_book
    for part in partition_book.metadata():
        num_edges += part['num_edges']
    assert len(edges) == num_edges, \
            'The length of boolean mask vector should be the number of edges in the graph.'

    if force_even:
        return _split_even(partition_book, rank, edges)
    else:
        # Get all edges that belong to the rank.
        local_eids = partition_book.partid2eids(partition_book.partid)
        return _split_local(partition_book, rank, edges, local_eids)

rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)