"""Define distributed graph."""

from collections.abc import MutableMapping
from collections import namedtuple

import os
import numpy as np

from ..heterograph import DGLHeteroGraph
from ..convert import heterograph as dgl_heterograph
from ..convert import graph as dgl_graph
from ..transforms import compact_graphs
from .. import heterograph_index
from .. import backend as F
from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all, DGLError
from .kvstore import KVServer, get_kvstore
from .._ffi.ndarray import empty_shared_mem
from ..ndarray import exist_shared_mem_array
from ..frame import infer_scheme
from .partition import load_partition, load_partition_book
from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book
from .graph_partition_book import HeteroDataName, parse_hetero_data_name
from .graph_partition_book import NodePartitionPolicy, EdgePartitionPolicy
from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
from . import rpc
from . import role
from .server_state import ServerState
from .rpc_server import start_server
from . import graph_services
from .graph_services import find_edges as dist_find_edges
from .graph_services import out_degrees as dist_out_degrees
from .graph_services import in_degrees as dist_in_degrees
from .dist_tensor import DistTensor

INIT_GRAPH = 800001

class InitGraphRequest(rpc.Request):
    """ Init graph on the backup servers.

    When the backup servers start, they don't load the graph structure.
    This request tells the backup servers that they can attach to the graph structure
    through shared memory.
    """
    def __init__(self, graph_name):
        self._graph_name = graph_name

    def __getstate__(self):
        return self._graph_name

    def __setstate__(self, state):
        self._graph_name = state

    def process_request(self, server_state):
        if server_state.graph is None:
            server_state.graph = _get_graph_from_shared_mem(self._graph_name)
        return InitGraphResponse(self._graph_name)

class InitGraphResponse(rpc.Response):
    """ Ack the init graph request
    """
    def __init__(self, graph_name):
        self._graph_name = graph_name

    def __getstate__(self):
        return self._graph_name

    def __setstate__(self, state):
        self._graph_name = state

def _copy_graph_to_shared_mem(g, graph_name, graph_format):
    new_g = g.shared_memory(graph_name, formats=graph_format)
    # We should share the node/edge data with the clients explicitly instead of putting them
    # in the KVStore because some of the node/edge data may be duplicated.
    new_g.ndata['inner_node'] = _to_shared_mem(g.ndata['inner_node'],
                                               _get_ndata_path(graph_name, 'inner_node'))
    new_g.ndata[NID] = _to_shared_mem(g.ndata[NID], _get_ndata_path(graph_name, NID))

    new_g.edata['inner_edge'] = _to_shared_mem(g.edata['inner_edge'],
                                               _get_edata_path(graph_name, 'inner_edge'))
    new_g.edata[EID] = _to_shared_mem(g.edata[EID], _get_edata_path(graph_name, EID))

    # For a heterogeneous graph, we need to share ETYPE as well;
    # for a homogeneous graph, ETYPE does not exist.
    if ETYPE in g.edata:
        new_g.edata[ETYPE] = _to_shared_mem(g.edata[ETYPE], _get_edata_path(graph_name, ETYPE))
    return new_g

FIELD_DICT = {'inner_node': F.int32,    # A flag indicating whether the node is inside a partition.
              'inner_edge': F.int32,    # A flag indicating whether the edge is inside a partition.
              NID: F.int64,
              EID: F.int64,
              NTYPE: F.int32,
              ETYPE: F.int32}

def _get_shared_mem_ndata(g, graph_name, name):
    ''' Get shared-memory node data from DistGraph server.

    This is called by the DistGraph client to access the node data in the DistGraph server
    with shared memory.
    '''
    shape = (g.number_of_nodes(),)
    dtype = FIELD_DICT[name]
    dtype = DTYPE_DICT[dtype]
    data = empty_shared_mem(_get_ndata_path(graph_name, name), False, shape, dtype)
    dlpack = data.to_dlpack()
    return F.zerocopy_from_dlpack(dlpack)

def _get_shared_mem_edata(g, graph_name, name):
    ''' Get shared-memory edge data from DistGraph server.

    This is called by the DistGraph client to access the edge data in the DistGraph server
    with shared memory.
    '''
    shape = (g.number_of_edges(),)
    dtype = FIELD_DICT[name]
    dtype = DTYPE_DICT[dtype]
    data = empty_shared_mem(_get_edata_path(graph_name, name), False, shape, dtype)
    dlpack = data.to_dlpack()
    return F.zerocopy_from_dlpack(dlpack)

def _exist_shared_mem_array(graph_name, name):
    return exist_shared_mem_array(_get_edata_path(graph_name, name))

def _get_graph_from_shared_mem(graph_name):
    ''' Get the graph from the DistGraph server.

    The DistGraph server puts the graph structure of the local partition in the shared memory.
    The client can access the graph structure and some metadata on nodes and edges directly
    through shared memory to reduce the overhead of data access.
    '''
    g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(graph_name)
    if g is None:
        return None
    g = DGLHeteroGraph(g, ntypes, etypes)

    g.ndata['inner_node'] = _get_shared_mem_ndata(g, graph_name, 'inner_node')
    g.ndata[NID] = _get_shared_mem_ndata(g, graph_name, NID)

    g.edata['inner_edge'] = _get_shared_mem_edata(g, graph_name, 'inner_edge')
    g.edata[EID] = _get_shared_mem_edata(g, graph_name, EID)

    # Only heterogeneous graphs have ETYPE.
    if _exist_shared_mem_array(graph_name, ETYPE):
        g.edata[ETYPE] = _get_shared_mem_edata(g, graph_name, ETYPE)
    return g

NodeSpace = namedtuple('NodeSpace', ['data'])
EdgeSpace = namedtuple('EdgeSpace', ['data'])

class HeteroNodeView(object):
    """A NodeView class to act as G.nodes for a DistGraph."""
    __slots__ = ['_graph']

    def __init__(self, graph):
        self._graph = graph

    def __getitem__(self, key):
        assert isinstance(key, str)
        return NodeSpace(data=NodeDataView(self._graph, key))

class HeteroEdgeView(object):
    """A NodeView class to act as G.nodes for a DistGraph."""
    __slots__ = ['_graph']

    def __init__(self, graph):
        self._graph = graph

    def __getitem__(self, key):
        assert isinstance(key, str)
        return EdgeSpace(data=EdgeDataView(self._graph, key))

class NodeDataView(MutableMapping):
    """The data view class when dist_graph.ndata[...].data is called.
    """
    __slots__ = ['_graph', '_data']

    def __init__(self, g, ntype=None):
        self._graph = g
        # When this is created, the server may have already loaded node data. We need to
        # initialize the node data in advance.
        names = g._get_ndata_names(ntype)
        if ntype is None:
            self._data = g._ndata_store
        else:
            if ntype in g._ndata_store:
                self._data = g._ndata_store[ntype]
            else:
                self._data = {}
                g._ndata_store[ntype] = self._data
        for name in names:
            assert name.is_node()
            policy = PartitionPolicy(name.policy_str, g.get_partition_book())
            dtype, shape, _ = g._client.get_data_meta(str(name))
            # We create a wrapper on the existing tensor in the kvstore.
            self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
                                                     part_policy=policy, attach=False)

    def _get_names(self):
        return list(self._data.keys())

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, val):
        self._data[key] = val

    def __delitem__(self, key):
        del self._data[key]

    def __len__(self):
        # The number of node data may change. Let's count it every time we need them.
        # It's not called frequently. It should be fine.
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def __repr__(self):
        reprs = {}
        for name in self._data:
            dtype = F.dtype(self._data[name])
            shape = F.shape(self._data[name])
            reprs[name] = 'DistTensor(shape={}, dtype={})'.format(str(shape), str(dtype))
        return repr(reprs)

class EdgeDataView(MutableMapping):
    """The data view class when G.edges[...].data is called.
    """
    __slots__ = ['_graph', '_data']

    def __init__(self, g, etype=None):
        self._graph = g
        # When this is created, the server may have already loaded edge data. We need to
        # initialize the edge data in advance.
        names = g._get_edata_names(etype)
        if etype is None:
            self._data = g._edata_store
        else:
            if etype in g._edata_store:
                self._data = g._edata_store[etype]
            else:
                self._data = {}
                g._edata_store[etype] = self._data
        for name in names:
            assert name.is_edge()
            policy = PartitionPolicy(name.policy_str, g.get_partition_book())
            dtype, shape, _ = g._client.get_data_meta(str(name))
            # We create a wrapper on the existing tensor in the kvstore.
            self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
                                                     part_policy=policy, attach=False)

    def _get_names(self):
        return list(self._data.keys())

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, val):
        self._data[key] = val

    def __delitem__(self, key):
        del self._data[key]

    def __len__(self):
        # The number of edge data may change. Let's count it every time we need them.
        # It's not called frequently. It should be fine.
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def __repr__(self):
        reprs = {}
        for name in self._data:
            dtype = F.dtype(self._data[name])
            shape = F.shape(self._data[name])
            reprs[name] = 'DistTensor(shape={}, dtype={})'.format(str(shape), str(dtype))
        return repr(reprs)


class DistGraphServer(KVServer):
    ''' The DistGraph server.

    This DistGraph server loads the graph data and sets up a service so that trainers and
    samplers can read data of a graph partition (graph structure, node data and edge data)
    from remote machines. A server is responsible for one graph partition.

    Currently, each machine runs only one main server with a set of backup servers to handle
    clients' requests. The main server and the backup servers all handle the requests for the same
    graph partition. They all share the partition data (graph structure and node/edge data) with
    shared memory.

    By default, the partition data is shared with the DistGraph clients that run on
    the same machine. However, a user can disable the shared memory option. This is useful
    when a user wants to run the server and the client on different machines.

    Parameters
    ----------
    server_id : int
        The server ID (start from 0).
    ip_config : str
        Path of IP configuration file.
    num_servers : int
        Server count on each machine.
    num_clients : int
        Total number of client nodes.
    part_config : str
        The path of the config file generated by the partition tool.
    disable_shared_mem : bool
        Disable shared memory.
    graph_format : str or list of str
        The graph formats.
    keep_alive : bool
        Whether to keep the server alive when clients exit.
    net_type : str
        Backend rpc type: ``'socket'`` or ``'tensorpipe'``
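
    Examples
    --------
    A minimal sketch of constructing and starting one server process; the paths and the
    counts below are illustrative placeholders (in a real deployment, DGL's launch script
    creates the servers):

    >>> server = dgl.distributed.DistGraphServer(
    ...     server_id=0, ip_config='ip_config.txt', num_servers=1,
    ...     num_clients=1, part_config='data/graph_name.json')
    >>> server.start()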
    '''
    def __init__(self, server_id, ip_config, num_servers,
                 num_clients, part_config, disable_shared_mem=False,
                 graph_format=('csc', 'coo'), keep_alive=False,
                 net_type='tensorpipe'):
        super(DistGraphServer, self).__init__(server_id=server_id,
                                              ip_config=ip_config,
                                              num_servers=num_servers,
                                              num_clients=num_clients)
        self.ip_config = ip_config
        self.num_servers = num_servers
        self.keep_alive = keep_alive
        self.net_type = net_type
        # Load graph partition data.
        if self.is_backup_server():
            # The backup server doesn't load the graph partition. It'll be initialized afterwards.
            self.gpb, graph_name, ntypes, etypes = load_partition_book(part_config, self.part_id)
            self.client_g = None
        else:
            self.client_g, node_feats, edge_feats, self.gpb, graph_name, \
                    ntypes, etypes = load_partition(part_config, self.part_id)
            print('load ' + graph_name)
            # Create the graph formats specified by the users.
            self.client_g = self.client_g.formats(graph_format)
            self.client_g.create_formats_()
            if not disable_shared_mem:
                self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name, graph_format)

        if not disable_shared_mem:
            self.gpb.shared_memory(graph_name)
        assert self.gpb.partid == self.part_id
        for ntype in ntypes:
            node_name = HeteroDataName(True, ntype, None)
            self.add_part_policy(PartitionPolicy(node_name.policy_str, self.gpb))
        for etype in etypes:
            edge_name = HeteroDataName(False, etype, None)
            self.add_part_policy(PartitionPolicy(edge_name.policy_str, self.gpb))

        if not self.is_backup_server():
            for name in node_feats:
                # The feature name has the following format: node_type + "/" + feature_name to avoid
                # feature name collision for different node types.
                ntype, feat_name = name.split('/')
                data_name = HeteroDataName(True, ntype, feat_name)
                self.init_data(name=str(data_name), policy_str=data_name.policy_str,
                               data_tensor=node_feats[name])
                self.orig_data.add(str(data_name))
            for name in edge_feats:
                # The feature name has the following format: edge_type + "/" + feature_name to avoid
                # feature name collision for different edge types.
                etype, feat_name = name.split('/')
                data_name = HeteroDataName(False, etype, feat_name)
                self.init_data(name=str(data_name), policy_str=data_name.policy_str,
                               data_tensor=edge_feats[name])
                self.orig_data.add(str(data_name))

    def start(self):
        """ Start graph store server.
        """
        # start server
        server_state = ServerState(kv_store=self, local_g=self.client_g,
                                   partition_book=self.gpb, keep_alive=self.keep_alive)
        print('start graph service on server {} for part {}'.format(
            self.server_id, self.part_id))
        start_server(server_id=self.server_id,
                     ip_config=self.ip_config,
                     num_servers=self.num_servers,
                     num_clients=self.num_clients,
                     server_state=server_state,
                     net_type=self.net_type)

class DistGraph:
    '''The class for accessing a distributed graph.

    This class provides a subset of DGLGraph APIs for accessing partitioned graph data in
    distributed GNN training and inference. Thus, its main use case is to work with
    distributed sampling APIs to generate mini-batches and perform forward and
    backward computation on the mini-batches.

    The class can run in two modes: the standalone mode and the distributed mode.

    * When a user runs the training script normally, ``DistGraph`` will be in the standalone mode.
      In this mode, the input data must be constructed by
      :py:meth:`~dgl.distributed.partition.partition_graph` with only one partition. This mode is
      used for testing and debugging purposes. In this mode, users have to provide ``part_config``
      so that ``DistGraph`` can load the input graph.
    * When a user runs the training script with the distributed launch script, ``DistGraph`` will
      be set into the distributed mode. This is used for actual distributed training. All data of
      partitions are loaded by the ``DistGraph`` servers, which are created by DGL's launch script.
      ``DistGraph`` connects with the servers to access the partitioned graph data.

    Currently, the ``DistGraph`` servers and clients run on the same set of machines
    in the distributed mode. ``DistGraph`` uses shared memory to access the partition data
    on the local machine. This gives the best performance for distributed training.

    Users may want to run ``DistGraph`` servers and clients on separate sets of machines.
    In this case, a user may want to disable shared memory by passing
    ``disable_shared_mem=True`` when creating ``DistGraphServer``. When shared memory is disabled,
    a user has to pass a partition book.

    Parameters
    ----------
    graph_name : str
        The name of the graph. This name has to be the same as the one used for
        partitioning a graph in :py:meth:`dgl.distributed.partition.partition_graph`.
    gpb : GraphPartitionBook, optional
        The partition book object. Normally, users do not need to provide the partition book.
        This argument is necessary only when users want to run server processes and trainer
        processes on different machines.
    part_config : str, optional
        The path of partition configuration file generated by
        :py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.

    Examples
    --------
    The example shows the creation of ``DistGraph`` in the standalone mode.

    >>> dgl.distributed.partition_graph(g, 'graph_name', 1, num_hops=1, part_method='metis',
    ...                                 out_path='output/', reshuffle=True)
    >>> g = dgl.distributed.DistGraph('graph_name', part_config='output/graph_name.json')

    The example shows the creation of ``DistGraph`` in the distributed mode.

    >>> g = dgl.distributed.DistGraph('graph-name')

    The code below shows the mini-batch training using ``DistGraph``.

    >>> def sample(seeds):
    ...     seeds = th.LongTensor(np.asarray(seeds))
    ...     frontier = dgl.distributed.sample_neighbors(g, seeds, 10)
    ...     return dgl.to_block(frontier, seeds)
    >>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000,
    ...                                             collate_fn=sample, shuffle=True)
    >>> for block in dataloader:
    ...     feat = g.ndata['features'][block.srcdata[dgl.NID]]
    ...     labels = g.ndata['labels'][block.dstdata[dgl.NID]]
    ...     pred = model(block, feat)

    Note
    ----
    DGL's distributed training by default runs server processes and trainer processes on the same
    set of machines. If users need to run them on different sets of machines, it requires
    manually setting up servers and trainers. The setup is not fully tested yet.
    '''
    def __init__(self, graph_name, gpb=None, part_config=None):
        self.graph_name = graph_name
        self._gpb_input = gpb
        if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
            assert part_config is not None, \
                    'When running in the standalone mode, the partition config file is required'
            self._client = get_kvstore()
            assert self._client is not None, \
                    'Distributed module is not initialized. Please call dgl.distributed.initialize.'
            # Load graph partition data.
            g, node_feats, edge_feats, self._gpb, _, _, _ = load_partition(part_config, 0)
            assert self._gpb.num_partitions() == 1, \
                    'The standalone mode can only work with the graph data with one partition'
            if self._gpb is None:
                self._gpb = gpb
            self._g = g
            for name in node_feats:
                # The feature name has the following format: node_type + "/" + feature_name.
                ntype, feat_name = name.split('/')
                self._client.add_data(str(HeteroDataName(True, ntype, feat_name)),
                                      node_feats[name],
                                      NodePartitionPolicy(self._gpb, ntype=ntype))
            for name in edge_feats:
                # The feature name has the following format: edge_type + "/" + feature_name.
                etype, feat_name = name.split('/')
                self._client.add_data(str(HeteroDataName(False, etype, feat_name)),
                                      edge_feats[name],
                                      EdgePartitionPolicy(self._gpb, etype=etype))
            self._client.map_shared_data(self._gpb)
            rpc.set_num_client(1)
        else:
            self._init()
            # Tell the backup servers to load the graph structure from shared memory.
            for server_id in range(self._client.num_servers):
                rpc.send_request(server_id, InitGraphRequest(graph_name))
            for server_id in range(self._client.num_servers):
                rpc.recv_response()
            self._client.barrier()

        self._ndata_store = {}
        self._edata_store = {}
        self._ndata = NodeDataView(self)
        self._edata = EdgeDataView(self)
        self._num_nodes = 0
        self._num_edges = 0
        for part_md in self._gpb.metadata():
            self._num_nodes += int(part_md['num_nodes'])
            self._num_edges += int(part_md['num_edges'])

        # When we store node/edge types in a list, they are stored in the order of type IDs.
        self._ntype_map = {ntype:i for i, ntype in enumerate(self.ntypes)}
        self._etype_map = {etype:i for i, etype in enumerate(self.etypes)}

        # Get canonical edge types.
        # TODO(zhengda) this requires the server to store the graph with coo format.
        eid = []
        for etype in self.etypes:
            type_eid = F.zeros((1,), F.int64, F.cpu())
            eid.append(self._gpb.map_to_homo_eid(type_eid, etype))
        eid = F.cat(eid, 0)
        src, dst = dist_find_edges(self, eid)
        src_tids, _ = self._gpb.map_to_per_ntype(src)
        dst_tids, _ = self._gpb.map_to_per_ntype(dst)
        self._canonical_etypes = []
        etype_ids = F.arange(0, len(self.etypes))
        for src_tid, etype_id, dst_tid in zip(src_tids, etype_ids, dst_tids):
            src_tid = F.as_scalar(src_tid)
            etype_id = F.as_scalar(etype_id)
            dst_tid = F.as_scalar(dst_tid)
            self._canonical_etypes.append((self.ntypes[src_tid], self.etypes[etype_id],
                                           self.ntypes[dst_tid]))
        self._etype2canonical = {}
        for src_type, etype, dst_type in self._canonical_etypes:
            if etype in self._etype2canonical:
                self._etype2canonical[etype] = ()
            else:
                self._etype2canonical[etype] = (src_type, etype, dst_type)

    def _init(self):
        self._client = get_kvstore()
        assert self._client is not None, \
                'Distributed module is not initialized. Please call dgl.distributed.initialize.'
        self._g = _get_graph_from_shared_mem(self.graph_name)
        self._gpb = get_shared_mem_partition_book(self.graph_name, self._g)
        if self._gpb is None:
            self._gpb = self._gpb_input
        self._client.map_shared_data(self._gpb)

    def __getstate__(self):
        return self.graph_name, self._gpb, self._canonical_etypes

    def __setstate__(self, state):
        self.graph_name, self._gpb_input, self._canonical_etypes = state
        self._init()

        self._etype2canonical = {}
        for src_type, etype, dst_type in self._canonical_etypes:
            if etype in self._etype2canonical:
                self._etype2canonical[etype] = ()
            else:
                self._etype2canonical[etype] = (src_type, etype, dst_type)
        self._ndata_store = {}
        self._edata_store = {}
        self._ndata = NodeDataView(self)
        self._edata = EdgeDataView(self)
        self._num_nodes = 0
        self._num_edges = 0
        for part_md in self._gpb.metadata():
            self._num_nodes += int(part_md['num_nodes'])
            self._num_edges += int(part_md['num_edges'])

    @property
    def local_partition(self):
        ''' Return the local partition on the client

        DistGraph provides a global view of the distributed graph. Internally,
        it may contain a partition of the graph if it is co-located with
        the server. When servers and clients run on separate sets of machines,
        this returns None.

        Returns
        -------
        DGLGraph
            The local partition
        '''
        return self._g

    @property
    def nodes(self):
        '''Return a node view
        '''
        return HeteroNodeView(self)

    @property
    def edges(self):
        '''Return an edge view
        '''
        return HeteroEdgeView(self)

    @property
    def ndata(self):
        """Return the data view of all the nodes.

        Returns
        -------
        NodeDataView
            The data view in the distributed graph storage.
        """
        assert len(self.ntypes) == 1, "ndata only works for a graph with one node type."
        return self._ndata

    @property
    def edata(self):
        """Return the data view of all the edges.

        Returns
        -------
        EdgeDataView
            The data view in the distributed graph storage.
        """
        assert len(self.etypes) == 1, "edata only works for a graph with one edge type."
        return self._edata

    @property
    def idtype(self):
        """The dtype of graph index

        Returns
        -------
        backend dtype object
            th.int32/th.int64 or tf.int32/tf.int64 etc.

        See Also
        --------
        long
        int
        """
        # TODO(da?): describe when self._g is None and idtype shouldn't be called.
        return F.int64

    @property
    def device(self):
        """Get the device context of this graph.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
        >>> print(g.device)
        device(type='cpu')
        >>> g = g.to('cuda:0')
        >>> print(g.device)
        device(type='cuda', index=0)

        Returns
        -------
        Device context object
        """
        # TODO(da?): describe when self._g is None and device shouldn't be called.
        return F.cpu()

    def is_pinned(self):
        """Check if the graph structure is pinned to the page-locked memory.

        Returns
        -------
        bool
            True if the graph structure is pinned.
        """
        # (Xin Yao): Currently we don't support pinning a DistGraph.
        return False

    @property
    def ntypes(self):
        """Return the list of node types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> g = DistGraph("test")
        >>> g.ntypes
        ['_U']
        """
        return self._gpb.ntypes

    @property
    def etypes(self):
        """Return the list of edge types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> g = DistGraph("test")
        >>> g.etypes
        ['_E']
        """
        # Currently, we only support a graph with one edge type.
        return self._gpb.etypes

    @property
    def canonical_etypes(self):
        """Return all the canonical edge types in the graph.

        A canonical edge type is a string triplet ``(str, str, str)``
        for source node type, edge type and destination node type.

        Returns
        -------
        list[(str, str, str)]
            All the canonical edge type triplets in a list.

        Notes
        -----
        DGL internally assigns an integer ID for each edge type. The returned
        edge type names are sorted according to their IDs.

        See Also
        --------
        etypes

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch

        >>> g = DistGraph("test")
        >>> g.canonical_etypes
        [('user', 'follows', 'user'),
         ('user', 'follows', 'game'),
         ('user', 'plays', 'game')]
        """
        return self._canonical_etypes

    def to_canonical_etype(self, etype):
        """Convert an edge type to the corresponding canonical edge type in the graph.

        A canonical edge type is a string triplet ``(str, str, str)``
        for source node type, edge type and destination node type.

        The function expects the given edge type name can uniquely identify a canonical edge
        type. DGL will raise error if this is not the case.

        Parameters
        ----------
        etype : str or (str, str, str)
            If :attr:`etype` is an edge type (str), it returns the corresponding canonical edge
            type in the graph. If :attr:`etype` is already a canonical edge type,
            it directly returns the input unchanged.

        Returns
        -------
        (str, str, str)
            The canonical edge type corresponding to the edge type.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch

        >>> g = DistGraph("test")
        >>> g.canonical_etypes
        [('user', 'follows', 'user'),
         ('user', 'follows', 'game'),
         ('user', 'plays', 'game')]

        >>> g.to_canonical_etype('plays')
        ('user', 'plays', 'game')
        >>> g.to_canonical_etype(('user', 'plays', 'game'))
        ('user', 'plays', 'game')

        See Also
        --------
        canonical_etypes
        """
        if etype is None:
            if len(self.etypes) != 1:
                raise DGLError('Edge type name must be specified if there are more than one '
                               'edge types.')
            etype = self.etypes[0]
        if isinstance(etype, tuple):
            return etype
        else:
            ret = self._etype2canonical.get(etype, None)
            if ret is None:
                raise DGLError('Edge type "{}" does not exist.'.format(etype))
            if len(ret) != 3:
                raise DGLError('Edge type "{}" is ambiguous. Please use canonical edge type '
                               'in the form of (srctype, etype, dsttype)'.format(etype))
            return ret

    def get_ntype_id(self, ntype):
        """Return the ID of the given node type.

        ntype can also be None. If so, there should be only one node type in the
        graph.

        Parameters
        ----------
        ntype : str
            Node type

        Returns
        -------
        int
        """
        if ntype is None:
            if len(self._ntype_map) != 1:
                raise DGLError('Node type name must be specified if there are more than one '
                               'node types.')
            return 0
        return self._ntype_map[ntype]

    def get_etype_id(self, etype):
        """Return the id of the given edge type.

        etype can also be None. If so, there should be only one edge type in the
        graph.

        Parameters
        ----------
        etype : str or tuple of str
            Edge type

        Returns
        -------
        int
        """
        if etype is None:
            if len(self._etype_map) != 1:
                raise DGLError('Edge type name must be specified if there are more than one '
                               'edge types.')
            return 0
        return self._etype_map[etype]

    def number_of_nodes(self, ntype=None):
        """Alias of :func:`num_nodes`"""
        return self.num_nodes(ntype)

    def number_of_edges(self, etype=None):
        """Alias of :func:`num_edges`"""
        return self.num_edges(etype)

    def num_nodes(self, ntype=None):
        """Return the total number of nodes in the distributed graph.

        Parameters
        ----------
        ntype : str, optional
            The node type name. If given, it returns the number of nodes of the
            type. If not given (default), it returns the total number of nodes of all types.

        Returns
        -------
        int
            The number of nodes

        Examples
        --------
        >>> g = dgl.distributed.DistGraph('ogb-product')
        >>> print(g.num_nodes())
        2449029
        """
        if ntype is None:
            if len(self.ntypes) == 1:
                return self._gpb._num_nodes(self.ntypes[0])
            else:
                return sum([self._gpb._num_nodes(ntype) for ntype in self.ntypes])
        return self._gpb._num_nodes(ntype)

    def num_edges(self, etype=None):
        """Return the total number of edges in the distributed graph.

        Parameters
        ----------
        etype : str or (str, str, str), optional
            The type name of the edges. The allowed type name formats are:

            * ``(str, str, str)`` for source node type, edge type and destination node type.
            * or one ``str`` edge type name if the name can uniquely identify a
              triplet format in the graph.

            If not provided, return the total number of edges regardless of the types
            in the graph.

        Returns
        -------
        int
            The number of edges

        Examples
        --------
        >>> g = dgl.distributed.DistGraph('ogb-product')
        >>> print(g.num_edges())
        123718280
        """
        if etype is None:
            if len(self.etypes) == 1:
                return self._gpb._num_edges(self.etypes[0])
            else:
                return sum([self._gpb._num_edges(etype) for etype in self.etypes])
        return self._gpb._num_edges(etype)

    def out_degrees(self, u=ALL):
        """Return the out-degree(s) of the given nodes.

        It computes the out-degree(s).
        It does not support heterogeneous graphs yet.

        Parameters
        ----------
        u : node IDs
            The node IDs. The allowed formats are:

            * ``int``: A single node.
            * Int Tensor: Each element is a node ID. The tensor must have the same device type
              and ID data type as the graph's.
            * iterable[int]: Each element is a node ID.

            If not given, return the out-degrees of all the nodes.

        Returns
        -------
        int or Tensor
            The out-degree(s) of the node(s) in a Tensor. The i-th element is the out-degree
            of the i-th input node. If :attr:`u` is an ``int``, return an ``int`` too.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch

        Query for all nodes.

        >>> g.out_degrees()
        tensor([2, 2, 0, 0])

        Query for nodes 1 and 2.

        >>> g.out_degrees(torch.tensor([1, 2]))
        tensor([2, 0])

        See Also
        --------
        in_degrees
        """
        if is_all(u):
            u = F.arange(0, self.number_of_nodes())
        return dist_out_degrees(self, u)

    def in_degrees(self, v=ALL):
        """Return the in-degree(s) of the given nodes.

        It computes the in-degree(s).
        It does not support heterogeneous graphs yet.

        Parameters
        ----------
        v : node IDs
            The node IDs. The allowed formats are:

            * ``int``: A single node.
            * Int Tensor: Each element is a node ID. The tensor must have the same device type
              and ID data type as the graph's.
            * iterable[int]: Each element is a node ID.

            If not given, return the in-degrees of all the nodes.

        Returns
        -------
        int or Tensor
            The in-degree(s) of the node(s) in a Tensor. The i-th element is the in-degree
            of the i-th input node. If :attr:`v` is an ``int``, return an ``int`` too.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch

        Query for all nodes.

        >>> g.in_degrees()
        tensor([0, 2, 1, 1])

        Query for nodes 1 and 2.

        >>> g.in_degrees(torch.tensor([1, 2]))
        tensor([2, 1])

        See Also
        --------
        out_degrees
        """
        if is_all(v):
            v = F.arange(0, self.number_of_nodes())
        return dist_in_degrees(self, v)

    def node_attr_schemes(self):
        """Return the node feature schemes.

        Each feature scheme is a named tuple that stores the shape and data type
        of the node feature.

        Returns
        -------
        dict of str to schemes
            The schemes of node feature columns.

        Examples
        --------
        The following uses PyTorch backend.

        >>> g.node_attr_schemes()
        {'h': Scheme(shape=(4,), dtype=torch.float32)}

        See Also
        --------
        edge_attr_schemes
        """
        schemes = {}
        for key in self.ndata:
            schemes[key] = infer_scheme(self.ndata[key])
        return schemes

    def edge_attr_schemes(self):
        """Return the edge feature schemes.

        Each feature scheme is a named tuple that stores the shape and data type
        of the edge feature.

        Returns
        -------
        dict of str to schemes
            The schemes of edge feature columns.

        Examples
        --------
        The following uses PyTorch backend.

        >>> g.edge_attr_schemes()
        {'h': Scheme(shape=(4,), dtype=torch.float32)}

        See Also
        --------
        node_attr_schemes
        """
        schemes = {}
        for key in self.edata:
            schemes[key] = infer_scheme(self.edata[key])
        return schemes

    def rank(self):
        ''' The rank of the current DistGraph.

        This returns a unique number to identify the DistGraph object among all of
        the client processes.

        Returns
        -------
        int
            The rank of the current DistGraph.
        '''
        return role.get_global_rank()

    def find_edges(self, edges, etype=None):
        """ Given an edge ID array, return the source
        and destination node ID array ``s`` and ``d``.  ``s[i]`` and ``d[i]``
        are source and destination node ID for edge ``eid[i]``.

        Parameters
        ----------
        edges : Int Tensor
            Each element is an ID. The tensor must have the same device type
            and ID data type as the graph's.

        etype : str or (str, str, str), optional
            The type names of the edges. The allowed type name formats are:

            * ``(str, str, str)`` for source node type, edge type and destination node type.
            * or one ``str`` edge type name if the name can uniquely identify a
              triplet format in the graph.

            Can be omitted if the graph has only one type of edges.

        Returns
        -------
        tensor
            The source node ID array.
        tensor
            The destination node ID array.
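
        Examples
        --------
        A minimal sketch, assuming a homogeneous graph ``g`` and the PyTorch backend:

        >>> import torch
        >>> src, dst = g.find_edges(torch.tensor([0, 1, 2]))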
        """
        if etype is None:
            assert len(self.etypes) == 1, 'find_edges requires etype for heterogeneous graphs.'

        gpb = self.get_partition_book()
        if len(gpb.etypes) > 1:
            # if etype is a canonical edge type (str, str, str), extract the edge type
            if isinstance(etype, tuple) and len(etype) == 3:
                etype = etype[1]
            edges = gpb.map_to_homo_eid(edges, etype)
        src, dst = dist_find_edges(self, edges)
        if len(gpb.ntypes) > 1:
            _, src = gpb.map_to_per_ntype(src)
            _, dst = gpb.map_to_per_ntype(dst)
        return src, dst

    def edge_subgraph(self, edges, relabel_nodes=True, store_ids=True):
        """Return a subgraph induced on the given edges.

        An edge-induced subgraph is equivalent to creating a new graph using the given
        edges. In addition to extracting the subgraph, DGL also copies the features
        of the extracted nodes and edges to the resulting graph. The copy is *lazy*
        and incurs data movement only when needed.

        If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
        them as the resulting graph. Thus, the resulting graph has the same set of relations
        as the input one.

        Parameters
        ----------
        edges : Int Tensor or dict[(str, str, str), Int Tensor]
            The edges to form the subgraph. Each element is an edge ID. The tensor must have
            the same device type and ID data type as the graph's.

            If the graph is homogeneous, one can directly pass an Int Tensor.
            Otherwise, the argument must be a dictionary with keys being edge types
            and values being the edge IDs in the above formats.
        relabel_nodes : bool, optional
            If True, it will remove the isolated nodes and relabel the incident nodes in the
            extracted subgraph.
        store_ids : bool, optional
            If True, it will store the raw IDs of the extracted edges in the ``edata`` of the
            resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
            also store the raw IDs of the incident nodes in the ``ndata`` of the resulting
            graph under name ``dgl.NID``.

        Returns
        -------
        G : DGLGraph
            The subgraph.
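
        Examples
        --------
        A minimal sketch, assuming a homogeneous graph ``g`` and the PyTorch backend:

        >>> import torch
        >>> sg = g.edge_subgraph(torch.tensor([0, 1, 2]))
        >>> sg.edata[dgl.EID]    # the raw IDs of the extracted edges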
        """
        if isinstance(edges, dict):
            # TODO(zhengda) we need to directly generate subgraph of all relations with
            # one invocation.
            if all(isinstance(etype, tuple) for etype in edges):
                # The keys are already canonical edge types.
                subg = {etype: self.find_edges(edges[etype], etype[1]) for etype in edges}
            else:
                subg = {}
                for etype in edges:
                    assert len(self._etype2canonical[etype]) == 3, \
                            'the etype in input edges is ambiguous'
                    subg[self._etype2canonical[etype]] = self.find_edges(edges[etype], etype)
            num_nodes = {ntype: self.number_of_nodes(ntype) for ntype in self.ntypes}
            subg = dgl_heterograph(subg, num_nodes_dict=num_nodes)
            for etype in edges:
                subg.edges[etype].data[EID] = edges[etype]
        else:
            assert len(self.etypes) == 1
            subg = self.find_edges(edges)
            subg = dgl_graph(subg, num_nodes=self.number_of_nodes())
            subg.edata[EID] = edges

        if relabel_nodes:
            subg = compact_graphs(subg)
        assert store_ids, 'edge_subgraph always stores original node/edge IDs.'
        return subg

    def get_partition_book(self):
        """Get the partition information.

        Returns
        -------
        GraphPartitionBook
            Object that stores all graph partition information.
        """
        return self._gpb

    def get_node_partition_policy(self, ntype):
        """Get the partition policy for a node type.

        When creating a new distributed tensor, we need to provide a partition policy
        that indicates how to distribute data of the distributed tensor in a cluster
        of machines. When we load a distributed graph in the cluster, we have pre-defined
        partition policies for each node type and each edge type. By providing
        the node type, we can reference to the pre-defined partition policy for the node type.

        Parameters
        ----------
        ntype : str
            The node type

        Returns
        -------
        PartitionPolicy
            The partition policy for the node type.
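
        Examples
        --------
        A short sketch of the typical use; the node type ``'user'`` and the feature shape
        are illustrative assumptions:

        >>> import torch
        >>> policy = g.get_node_partition_policy('user')
        >>> feat = dgl.distributed.DistTensor((g.num_nodes('user'), 16), torch.float32,
        ...                                   part_policy=policy)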
        """
        return NodePartitionPolicy(self.get_partition_book(), ntype)

    def get_edge_partition_policy(self, etype):
        """Get the partition policy for an edge type.

        When creating a new distributed tensor, we need to provide a partition policy
        that indicates how to distribute data of the distributed tensor in a cluster
        of machines. When we load a distributed graph in the cluster, we have pre-defined
        partition policies for each node type and each edge type. By providing
        the edge type, we can reference to the pre-defined partition policy for the edge type.

        Parameters
        ----------
        etype : str
            The edge type

        Returns
        -------
        PartitionPolicy
            The partition policy for the edge type.
        """
        return EdgePartitionPolicy(self.get_partition_book(), etype)

    def barrier(self):
        '''Barrier for all client nodes.

        This API blocks the current process until all the clients invoke this API.
        Please use this API with caution.
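
        Examples
        --------
        A minimal sketch; ``g`` is a ``DistGraph`` held by every trainer process:

        >>> # ... each trainer finishes writing its own updates ...
        >>> g.barrier()    # wait until all trainers reach this point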
        '''
        self._client.barrier()

    def sample_neighbors(self, seed_nodes, fanout, edge_dir='in', prob=None,
                         exclude_edges=None, replace=False,
                         output_device=None):
        # pylint: disable=unused-argument
        """Sample neighbors from a distributed graph."""
        # Currently prob, exclude_edges, output_device, and edge_dir are ignored.
        if len(self.etypes) > 1:
            frontier = graph_services.sample_etype_neighbors(
                self, seed_nodes, ETYPE, fanout, replace=replace)
        else:
            frontier = graph_services.sample_neighbors(
                self, seed_nodes, fanout, replace=replace)
        return frontier

    def _get_ndata_names(self, ntype=None):
        ''' Get the names of all node data.
        '''
        names = self._client.gdata_name_list()
        ndata_names = []
        for name in names:
            name = parse_hetero_data_name(name)
            right_type = (name.get_type() == ntype) if ntype is not None else True
            if name.is_node() and right_type:
                ndata_names.append(name)
        return ndata_names

    def _get_edata_names(self, etype=None):
        ''' Get the names of all edge data.
        '''
        names = self._client.gdata_name_list()
        edata_names = []
        for name in names:
            name = parse_hetero_data_name(name)
            right_type = (name.get_type() == etype) if etype is not None else True
            if name.is_edge() and right_type:
                edata_names.append(name)
        return edata_names

def _get_overlap(mask_arr, ids):
    """ Select the IDs given a boolean mask array.

    The boolean mask array indicates all of the IDs to be selected. We want to
    find the overlap between the IDs selected by the boolean mask array and
    the ID array.

    Parameters
    ----------
    mask_arr : 1D tensor
        A boolean mask array.
    ids : 1D tensor
        A vector with IDs.

    Returns
    -------
    1D tensor
        The selected IDs.
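
    For example (illustrative values), with ``mask_arr = [1, 0, 1, 1]`` and
    ``ids = [0, 1, 3]``, the selected IDs are ``[0, 3]``.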
    """
    if isinstance(mask_arr, DistTensor):
        masks = mask_arr[ids]
        return F.boolean_mask(ids, masks)
    else:
        masks = F.gather_row(F.tensor(mask_arr), ids)
        return F.boolean_mask(ids, masks)

def _split_local(partition_book, rank, elements, local_eles):
    ''' Split the input element list with respect to data locality.
    '''
    num_clients = role.get_num_trainers()
    num_client_per_part = num_clients // partition_book.num_partitions()
    if rank is None:
        rank = role.get_trainer_rank()
    assert rank < num_clients, \
            'The input rank ({}) is incorrect. #Trainers: {}'.format(rank, num_clients)
    # all ranks of the clients in the same machine are in a contiguous range.
    client_id_in_part = rank % num_client_per_part
    local_eles = _get_overlap(elements, local_eles)

    # get a subset for the local client.
    size = len(local_eles) // num_client_per_part
    # if this isn't the last client in the partition.
    if client_id_in_part + 1 < num_client_per_part:
        return local_eles[(size * client_id_in_part):(size * (client_id_in_part + 1))]
    else:
        return local_eles[(size * client_id_in_part):]

def _even_offset(n, k):
    ''' Split an array of length n into k segments whose lengths differ by at most 1.
        Return the offsets of the segments.
    '''
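    # A worked example (illustrative): _even_offset(10, 3) returns [0, 4, 7, 10],
    # i.e. segment lengths 4, 3 and 3.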
    eles_per_part = n // k
    offset = np.array([0] + [eles_per_part] * k, dtype=int)
    offset[1 : n - eles_per_part * k + 1] += 1
    return np.cumsum(offset)

def _split_even_to_part(partition_book, elements):
    ''' Split the input element list evenly.
    '''
    # here we divide the element list as evenly as possible. If we use range partitioning,
    # the split results also respect the data locality. Range partitioning is the default
    # strategy.
    # TODO(zhengda) we need another way to divide the list for other partitioning strategy.
    if isinstance(elements, DistTensor):
        nonzero_count = elements.count_nonzero()
    else:
        elements = F.tensor(elements)
        nonzero_count = F.count_nonzero(elements)
    # Compute the offsets of the splits and ensure that the partition sizes differ by at most 1.
    offsets = _even_offset(nonzero_count, partition_book.num_partitions())
    assert offsets[-1] == nonzero_count

    # Get the elements that belong to the partition.
    partid = partition_book.partid
    left, right = offsets[partid], offsets[partid + 1]

    x = y = 0
    num_elements = len(elements)
    block_size = num_elements // partition_book.num_partitions()
    part_eles = None
    # compute the nonzero tensor of each partition instead of whole tensor to save memory
    for idx in range(0, num_elements, block_size):
        nonzero_block = F.nonzero_1d(elements[idx:min(idx+block_size, num_elements)])
        x = y
        y += len(nonzero_block)
        if y > left and x < right:
            start = max(x, left) - x
            end = min(y, right) - x
            tmp = nonzero_block[start:end] + idx
            if part_eles is None:
                part_eles = tmp
            else:
                part_eles = F.cat((part_eles, tmp), 0)
        elif x >= right:
            break

    return part_eles
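
# A sketch of the windowing logic above with illustrative numbers: suppose
# nonzero_count = 10 and there are 2 partitions, so offsets = [0, 5, 10] and
# partition 1 owns the nonzero elements numbered [5, 10). When a scanned block
# has the running nonzero window [x, y) = [3, 8), its overlap with [5, 10) is
# [5, 8), so nonzero_block[2:5] + idx is appended to part_eles.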

def _split_random_within_part(partition_book, rank, part_eles):
    # If there are more than one client in a partition, we need to randomly select a subset of
    # elements in the partition for a client. We have to make sure that the set of elements
    # for different clients are disjoint.

    num_clients = role.get_num_trainers()
    num_client_per_part = num_clients // partition_book.num_partitions()
    if num_client_per_part == 1:
        return part_eles
    if rank is None:
        rank = role.get_trainer_rank()
    assert rank < num_clients, \
            'The input rank ({}) is incorrect. #Trainers: {}'.format(rank, num_clients)
    client_id_in_part = rank % num_client_per_part
    offset = _even_offset(len(part_eles), num_client_per_part)

    # We set the random seed for each partition, so that each process (client) in a partition
    # permutes the elements of the partition in the same way and thus gets a disjoint subset
    # of elements.
    np.random.seed(partition_book.partid)
    rand_idx = np.random.permutation(len(part_eles))
    rand_idx = rand_idx[offset[client_id_in_part] : offset[client_id_in_part + 1]]
    idx, _ = F.sort_1d(F.tensor(rand_idx))
    return F.gather_row(part_eles, idx)
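
# A sketch of _split_random_within_part() with illustrative numbers: with 4
# trainers per partition and len(part_eles) == 10, offset = [0, 3, 6, 8, 10].
# Every trainer in the partition seeds NumPy with the same partition ID and
# hence draws the same permutation; trainer 1 keeps the sorted positions
# rand_idx[3:6], so the four subsets are disjoint and together cover part_eles.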

def _split_by_trainer_id(partition_book, part_eles, trainer_id,
                         num_client_per_part, client_id_in_part):
    # TODO(zhengda): MXNet cannot deal with empty tensors, which makes the implementation
    # much more difficult. Let's just use numpy for the computation for now. We just
    # perform operations on vectors. It shouldn't be too difficult.
    trainer_id = F.asnumpy(trainer_id)
    part_eles = F.asnumpy(part_eles)
    part_id = trainer_id // num_client_per_part
    trainer_id = trainer_id % num_client_per_part
    local_eles = part_eles[np.nonzero(part_id[part_eles] == partition_book.partid)[0]]
    # these are the IDs of the local elements in the partition. The IDs are global IDs.
    remote_eles = part_eles[np.nonzero(part_id[part_eles] != partition_book.partid)[0]]
    # these are the IDs of the remote elements in the partition. The IDs are global IDs.
    local_eles_idx = np.concatenate(
        [np.nonzero(trainer_id[local_eles] == i)[0] for i in range(num_client_per_part)],
        # trainer_id[local_eles] is the trainer ids of local nodes in the partition and we
        # pick out the indices where the node belongs to each trainer i respectively, and
        # concatenate them.
        axis=0
    )
    # `local_eles_idx` is used to sort `local_eles` according to `trainer_id`. It is a
    # permutation of 0...(len(local_eles)-1)
    local_eles = local_eles[local_eles_idx]

    # evenly split local elements among the trainers
    local_offsets = _even_offset(len(local_eles), num_client_per_part)
    # evenly split remote elements among the trainers
    remote_offsets = _even_offset(len(remote_eles), num_client_per_part)

    client_local_eles = local_eles[
        local_offsets[client_id_in_part]:local_offsets[client_id_in_part + 1]]
    client_remote_eles = remote_eles[
        remote_offsets[client_id_in_part]:remote_offsets[client_id_in_part + 1]]
    client_eles = np.concatenate([client_local_eles, client_remote_eles], axis=0)
    return F.tensor(client_eles)
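
# A sketch of _split_by_trainer_id(): with 2 trainers per partition, a trainer
# first takes its even share of the local elements (sorted by their assigned
# trainer ID, so each trainer tends to receive the elements assigned to it) and
# then its even share of the remote elements; the per-trainer subsets are
# disjoint and together cover part_eles.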

def node_split(nodes, partition_book=None, ntype='_N', rank=None, force_even=True,
               node_trainer_ids=None):
    ''' Split nodes and return a subset for the local rank.

    This function splits the input nodes based on the partition book and
    returns a subset of nodes for the local rank. This method is used for
    dividing workloads for distributed training.

    The input nodes are stored as a boolean mask vector. The length of the vector is
    the same as the number of nodes in the graph; a nonzero value indicates that the node at
    the corresponding location is included.

    There are two strategies to split the nodes. If ``force_even`` is False, the nodes
    are split to maximize data locality: all nodes that belong to a process are returned.
    If ``force_even`` is set to True (the default), the nodes are split evenly so
    that each process gets almost the same number of nodes.

    When ``force_even`` is True, data locality is still preserved if the graph is partitioned
    with Metis and the node/edge IDs are shuffled.
    In this case, the majority of the nodes returned for a process are the ones that
    belong to the process. If the node/edge IDs are not shuffled, data locality is not guaranteed.

    Parameters
    ----------
    nodes : 1D tensor or DistTensor
        A boolean mask vector that indicates input nodes.
    partition_book : GraphPartitionBook, optional
        The graph partition book
    ntype : str, optional
        The node type of the input nodes.
    rank : int, optional
        The rank of a process. If not given, the rank of the current process is used.
    force_even : bool, optional
        Force the nodes to be split evenly.
    node_trainer_ids : 1D tensor or DistTensor, optional
        If not None, split the nodes to the trainers on the same machine according to
        trainer IDs assigned to each node. Otherwise, split randomly.

    Returns
    -------
    1D-tensor
        The vector of node IDs that belong to the rank.
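
    Examples
    --------
    A minimal sketch, assuming a launched distributed training setup in which ``g`` is a
    hypothetical :class:`DistGraph` with a boolean ``'train_mask'`` node feature:

    >>> train_mask = g.ndata['train_mask']
    >>> pb = g.get_partition_book()
    >>> train_nids = dgl.distributed.node_split(train_mask, pb)
    >>> # each trainer process receives a disjoint subset of the training node IDs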
    '''
    if not isinstance(nodes, DistTensor):
        assert partition_book is not None, 'Regular tensor requires a partition book.'
    elif partition_book is None:
        partition_book = nodes.part_policy.partition_book

    assert len(nodes) == partition_book._num_nodes(ntype), \
            'The length of the boolean mask vector should be the number of nodes in the graph.'
    if rank is None:
        rank = role.get_trainer_rank()
    if force_even:
        num_clients = role.get_num_trainers()
        num_client_per_part = num_clients // partition_book.num_partitions()
        assert num_clients % partition_book.num_partitions() == 0, \
                'The total number of clients should be a multiple of the number of partitions.'
        part_nid = _split_even_to_part(partition_book, nodes)
        if num_client_per_part == 1:
            return part_nid
        elif node_trainer_ids is None:
            return _split_random_within_part(partition_book, rank, part_nid)
        else:
            trainer_id = node_trainer_ids[0:len(node_trainer_ids)]
            max_trainer_id = F.as_scalar(F.reduce_max(trainer_id)) + 1

            if max_trainer_id > num_clients:
                # The trainer-ID assignment produced at partitioning time should still be usable
                # when the actual number of trainers is smaller than the `num_trainers_per_machine`
                # specified during partitioning, so we map the IDs down onto the available trainers.
                assert max_trainer_id % num_clients == 0
                trainer_id //= (max_trainer_id // num_clients)

            client_id_in_part = rank % num_client_per_part
            return _split_by_trainer_id(partition_book, part_nid, trainer_id,
                                        num_client_per_part, client_id_in_part)
    else:
        # Get all nodes that belong to the rank.
        local_nids = partition_book.partid2nids(partition_book.partid, ntype=ntype)
        return _split_local(partition_book, rank, nodes, local_nids)

def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=True,
               edge_trainer_ids=None):
    ''' Split edges and return a subset for the local rank.

    This function splits the input edges based on the partition book and
    returns a subset of edges for the local rank. This method is used for
    dividing workloads for distributed training.

    The input edges are stored as a boolean mask vector. The length of the vector is
    the same as the number of edges in the graph; a nonzero value indicates that the edge at
    the corresponding location is included.

    There are two strategies to split the edges. If ``force_even`` is False, the edges
    are split to maximize data locality: all edges that belong to a process are returned.
    If ``force_even`` is set to True (the default), the edges are split evenly so
    that each process gets almost the same number of edges.

    When ``force_even`` is True, data locality is still preserved if the graph is partitioned
    with Metis and the node/edge IDs are shuffled.
    In this case, the majority of the edges returned for a process are the ones that
    belong to the process. If the node/edge IDs are not shuffled, data locality is not guaranteed.

    Parameters
    ----------
    edges : 1D tensor or DistTensor
        A boolean mask vector that indicates input edges.
    partition_book : GraphPartitionBook, optional
        The graph partition book
    etype : str, optional
        The edge type of the input edges.
    rank : int, optional
        The rank of a process. If not given, the rank of the current process is used.
    force_even : bool, optional
        Force the edges to be split evenly.
    edge_trainer_ids : 1D tensor or DistTensor, optional
        If not None, split the edges to the trainers on the same machine according to
        trainer IDs assigned to each edge. Otherwise, split randomly.

    Returns
    -------
    1D-tensor
        The vector of edge IDs that belong to the rank.
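
    Examples
    --------
    A minimal sketch, assuming a launched distributed training setup in which ``g`` is a
    hypothetical :class:`DistGraph` with a boolean ``'train_mask'`` edge feature:

    >>> train_mask = g.edata['train_mask']
    >>> pb = g.get_partition_book()
    >>> train_eids = dgl.distributed.edge_split(train_mask, pb)
    >>> # each trainer process receives a disjoint subset of the training edge IDs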
    '''
    if not isinstance(edges, DistTensor):
        assert partition_book is not None, 'Regular tensor requires a partition book.'
    elif partition_book is None:
        partition_book = edges.part_policy.partition_book
    assert len(edges) == partition_book._num_edges(etype), \
            'The length of the boolean mask vector should be the number of edges in the graph.'
    if rank is None:
        rank = role.get_trainer_rank()
    if force_even:
        num_clients = role.get_num_trainers()
        num_client_per_part = num_clients // partition_book.num_partitions()
        assert num_clients % partition_book.num_partitions() == 0, \
                'The total number of clients should be a multiple of the number of partitions.'
        part_eid = _split_even_to_part(partition_book, edges)
        if num_client_per_part == 1:
            return part_eid
        elif edge_trainer_ids is None:
            return _split_random_within_part(partition_book, rank, part_eid)
        else:
            trainer_id = edge_trainer_ids[0:len(edge_trainer_ids)]
            max_trainer_id = F.as_scalar(F.reduce_max(trainer_id)) + 1

            if max_trainer_id > num_clients:
                # The trainer-ID assignment produced at partitioning time should still be usable
                # when the actual number of trainers is smaller than the `num_trainers_per_machine`
                # specified during partitioning, so we map the IDs down onto the available trainers.
                assert max_trainer_id % num_clients == 0
                trainer_id //= (max_trainer_id // num_clients)

            client_id_in_part = rank % num_client_per_part
            return _split_by_trainer_id(partition_book, part_eid, trainer_id,
                                        num_client_per_part, client_id_in_part)
    else:
        # Get all edges that belong to the rank.
        local_eids = partition_book.partid2eids(partition_book.partid, etype=etype)
        return _split_local(partition_book, rank, edges, local_eids)

rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)