"""Define distributed graph."""

from collections.abc import MutableMapping
from collections import namedtuple

import os
import numpy as np

from ..heterograph import DGLHeteroGraph
from ..convert import heterograph as dgl_heterograph
from ..convert import graph as dgl_graph
from ..transforms import compact_graphs
from .. import heterograph_index
from .. import backend as F
from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all, DGLError
from .kvstore import KVServer, get_kvstore
from .._ffi.ndarray import empty_shared_mem
from ..ndarray import exist_shared_mem_array
from ..frame import infer_scheme
from .partition import load_partition, load_partition_feats, load_partition_book
from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book
from .graph_partition_book import HeteroDataName, parse_hetero_data_name
from .graph_partition_book import NodePartitionPolicy, EdgePartitionPolicy
from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
from . import rpc
from . import role
from .server_state import ServerState
from .rpc_server import start_server
from . import graph_services
from .graph_services import find_edges as dist_find_edges
from .graph_services import out_degrees as dist_out_degrees
from .graph_services import in_degrees as dist_in_degrees
from .dist_tensor import DistTensor

INIT_GRAPH = 800001

class InitGraphRequest(rpc.Request):
    """ Init graph on the backup servers.

    When the backup servers start, they don't load the graph structure.
    This request tells the backup servers that they can map to the graph structure
    with shared memory.
    """
    def __init__(self, graph_name):
        self._graph_name = graph_name

    def __getstate__(self):
        return self._graph_name

    def __setstate__(self, state):
        self._graph_name = state

    def process_request(self, server_state):
        if server_state.graph is None:
            server_state.graph = _get_graph_from_shared_mem(self._graph_name)
        return InitGraphResponse(self._graph_name)

class InitGraphResponse(rpc.Response):
    """ Ack the init graph request
    """
    def __init__(self, graph_name):
        self._graph_name = graph_name

    def __getstate__(self):
        return self._graph_name

    def __setstate__(self, state):
        self._graph_name = state

# Register the request/response handlers so servers can serve InitGraphRequest.
rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)

def _copy_graph_to_shared_mem(g, graph_name, graph_format):
    new_g = g.shared_memory(graph_name, formats=graph_format)
    # We should share the node/edge data with the clients explicitly instead of putting them
    # in the KVStore because some of the node/edge data may be duplicated.
    new_g.ndata['inner_node'] = _to_shared_mem(g.ndata['inner_node'],
                                               _get_ndata_path(graph_name, 'inner_node'))
    new_g.ndata[NID] = _to_shared_mem(g.ndata[NID], _get_ndata_path(graph_name, NID))

    new_g.edata['inner_edge'] = _to_shared_mem(g.edata['inner_edge'],
                                               _get_edata_path(graph_name, 'inner_edge'))
    new_g.edata[EID] = _to_shared_mem(g.edata[EID], _get_edata_path(graph_name, EID))
    # For a heterogeneous graph, we also need to share ETYPE.
    # For a homogeneous graph, ETYPE does not exist.
    if ETYPE in g.edata:
        new_g.edata[ETYPE] = _to_shared_mem(g.edata[ETYPE], _get_edata_path(graph_name, ETYPE))
    return new_g

FIELD_DICT = {'inner_node': F.int32,    # A flag indicating whether the node is inside a partition.
              'inner_edge': F.int32,    # A flag indicating whether the edge is inside a partition.
              NID: F.int64,
              EID: F.int64,
              NTYPE: F.int32,
              ETYPE: F.int32}

def _get_shared_mem_ndata(g, graph_name, name):
    ''' Get shared-memory node data from DistGraph server.

    This is called by the DistGraph client to access the node data in the DistGraph server
    with shared memory.
    '''
    shape = (g.number_of_nodes(),)
    dtype = FIELD_DICT[name]
    dtype = DTYPE_DICT[dtype]
    data = empty_shared_mem(_get_ndata_path(graph_name, name), False, shape, dtype)
    dlpack = data.to_dlpack()
    return F.zerocopy_from_dlpack(dlpack)

def _get_shared_mem_edata(g, graph_name, name):
    ''' Get shared-memory edge data from DistGraph server.

    This is called by the DistGraph client to access the edge data in the DistGraph server
    with shared memory.
    '''
    shape = (g.number_of_edges(),)
    dtype = FIELD_DICT[name]
    dtype = DTYPE_DICT[dtype]
    data = empty_shared_mem(_get_edata_path(graph_name, name), False, shape, dtype)
    dlpack = data.to_dlpack()
    return F.zerocopy_from_dlpack(dlpack)

def _exist_shared_mem_array(graph_name, name):
    return exist_shared_mem_array(_get_edata_path(graph_name, name))

def _get_graph_from_shared_mem(graph_name):
    ''' Get the graph from the DistGraph server.

    The DistGraph server puts the graph structure of the local partition in the shared memory.
    The client can access the graph structure and some metadata on nodes and edges directly
    through shared memory to reduce the overhead of data access.
    '''
    g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(graph_name)
    if g is None:
        return None
    g = DGLHeteroGraph(g, ntypes, etypes)
    g.ndata['inner_node'] = _get_shared_mem_ndata(g, graph_name, 'inner_node')
    g.ndata[NID] = _get_shared_mem_ndata(g, graph_name, NID)

    g.edata['inner_edge'] = _get_shared_mem_edata(g, graph_name, 'inner_edge')
    g.edata[EID] = _get_shared_mem_edata(g, graph_name, EID)

    # heterogeneous graph has ETYPE
    if _exist_shared_mem_array(graph_name, ETYPE):
        g.edata[ETYPE] = _get_shared_mem_edata(g, graph_name, ETYPE)
    return g

NodeSpace = namedtuple('NodeSpace', ['data'])
EdgeSpace = namedtuple('EdgeSpace', ['data'])

class HeteroNodeView(object):
    """A NodeView class to act as G.nodes for a DistGraph."""
    __slots__ = ['_graph']

    def __init__(self, graph):
        self._graph = graph

    def __getitem__(self, key):
        assert isinstance(key, str)
        return NodeSpace(data=NodeDataView(self._graph, key))

class HeteroEdgeView(object):
    """An EdgeView class to act as G.edges for a DistGraph."""
    __slots__ = ['_graph']

    def __init__(self, graph):
        self._graph = graph

    def __getitem__(self, key):
        assert isinstance(key, str)
        return EdgeSpace(data=EdgeDataView(self._graph, key))
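
# Illustrative use of the node/edge views above (a sketch; the graph, type and feature
# names below are placeholders, not part of the code in this file):
#
#     user_feat = g.nodes['user'].data['feat']        # NodeDataView lookup -> DistTensor
#     follow_w = g.edges['follows'].data['weight']    # EdgeDataView lookup -> DistTensor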

class NodeDataView(MutableMapping):
    """The data view class when G.ndata[...] or G.nodes[...].data is called.
    """
    __slots__ = ['_graph', '_data']

    def __init__(self, g, ntype=None):
        self._graph = g
        # When this is created, the server may have already loaded node data. We need to
        # initialize the node data in advance.
        names = g._get_ndata_names(ntype)
        if ntype is None:
            self._data = g._ndata_store
        else:
            if ntype in g._ndata_store:
                self._data = g._ndata_store[ntype]
            else:
                self._data = {}
                g._ndata_store[ntype] = self._data
        for name in names:
            assert name.is_node()
            policy = PartitionPolicy(name.policy_str, g.get_partition_book())
            dtype, shape, _ = g._client.get_data_meta(str(name))
            # We create a wrapper on the existing tensor in the kvstore.
            self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
                                                     part_policy=policy, attach=False)

    def _get_names(self):
        return list(self._data.keys())

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, val):
        self._data[key] = val

    def __delitem__(self, key):
        del self._data[key]

    def __len__(self):
        # The number of node data may change. Let's count it every time we need them.
        # It's not called frequently. It should be fine.
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def __repr__(self):
        reprs = {}
        for name in self._data:
            dtype = F.dtype(self._data[name])
            shape = F.shape(self._data[name])
            reprs[name] = 'DistTensor(shape={}, dtype={})'.format(str(shape), str(dtype))
        return repr(reprs)

class EdgeDataView(MutableMapping):
    """The data view class when G.edges[...].data is called.
    """
    __slots__ = ['_graph', '_data']

    def __init__(self, g, etype=None):
        self._graph = g
        # When this is created, the server may have already loaded edge data. We need to
        # initialize the edge data in advance.
        names = g._get_edata_names(etype)
        if etype is None:
            self._data = g._edata_store
        else:
            if etype in g._edata_store:
                self._data = g._edata_store[etype]
            else:
                self._data = {}
                g._edata_store[etype] = self._data
        for name in names:
            assert name.is_edge()
            policy = PartitionPolicy(name.policy_str, g.get_partition_book())
            dtype, shape, _ = g._client.get_data_meta(str(name))
            # We create a wrapper on the existing tensor in the kvstore.
            self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
                                                     part_policy=policy, attach=False)

    def _get_names(self):
        return list(self._data.keys())

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, val):
        self._data[key] = val

    def __delitem__(self, key):
        del self._data[key]

    def __len__(self):
        # The number of edge data may change. Let's count it every time we need them.
        # It's not called frequently. It should be fine.
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def __repr__(self):
        reprs = {}
        for name in self._data:
            dtype = F.dtype(self._data[name])
            shape = F.shape(self._data[name])
            reprs[name] = 'DistTensor(shape={}, dtype={})'.format(str(shape), str(dtype))
        return repr(reprs)


class DistGraphServer(KVServer):
    ''' The DistGraph server.

    This DistGraph server loads the graph data and sets up a service so that trainers and
    samplers can read data of a graph partition (graph structure, node data and edge data)
    from remote machines. A server is responsible for one graph partition.

    Currently, each machine runs only one main server with a set of backup servers to handle
    clients' requests. The main server and the backup servers all handle the requests for the same
    graph partition. They all share the partition data (graph structure and node/edge data) with
    shared memory.

    By default, the partition data is shared with the DistGraph clients that run on
    the same machine. However, a user can disable the shared-memory option. This is useful when
    a user wants to run the servers and the clients on different machines.

    Parameters
    ----------
    server_id : int
        The server ID (starts from 0).
    ip_config : str
        Path of the IP configuration file.
    num_servers : int
        Server count on each machine.
    num_clients : int
        Total number of client nodes.
    part_config : str
        The path of the config file generated by the partition tool.
    disable_shared_mem : bool
        Disable shared memory.
    graph_format : str or list of str
        The graph formats.
    keep_alive : bool
        Whether to keep the server alive when clients exit.
    net_type : str
        Backend RPC type: ``'socket'`` or ``'tensorpipe'``.
    '''
    def __init__(self, server_id, ip_config, num_servers,
                 num_clients, part_config, disable_shared_mem=False,
                 graph_format=('csc', 'coo'), keep_alive=False,
                 net_type='socket'):
        super(DistGraphServer, self).__init__(server_id=server_id,
                                              ip_config=ip_config,
                                              num_servers=num_servers,
                                              num_clients=num_clients)
        self.ip_config = ip_config
        self.num_servers = num_servers
        self.keep_alive = keep_alive
        self.net_type = net_type
        # Load graph partition data.
        if self.is_backup_server():
            # The backup server doesn't load the graph partition. It'll be initialized afterwards.
            self.gpb, graph_name, ntypes, etypes = load_partition_book(part_config, self.part_id)
            self.client_g = None
        else:
            # Loading of node/edge_feats is deferred to lower the peak memory consumption.
            self.client_g, _, _, self.gpb, graph_name, \
                    ntypes, etypes = load_partition(part_config, self.part_id, load_feats=False)
            print('load ' + graph_name)
            # formatting dtype
            # TODO(Rui) Forcing the dtypes is not a perfect solution.
            #   We'd better store all dtypes when mapping to shared memory
            #   and map back with original dtypes.
            for k, dtype in FIELD_DICT.items():
                if k in self.client_g.ndata:
                    self.client_g.ndata[k] = F.astype(
                        self.client_g.ndata[k], dtype)
                if k in self.client_g.edata:
                    self.client_g.edata[k] = F.astype(
                        self.client_g.edata[k], dtype)
            # Create the graph formats specified by the users.
            self.client_g = self.client_g.formats(graph_format)
            self.client_g.create_formats_()
            if not disable_shared_mem:
                self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name, graph_format)

        if not disable_shared_mem:
            self.gpb.shared_memory(graph_name)
        assert self.gpb.partid == self.part_id
        for ntype in ntypes:
            node_name = HeteroDataName(True, ntype, None)
            self.add_part_policy(PartitionPolicy(node_name.policy_str, self.gpb))
        for etype in etypes:
            edge_name = HeteroDataName(False, etype, None)
            self.add_part_policy(PartitionPolicy(edge_name.policy_str, self.gpb))

        if not self.is_backup_server():
            node_feats, edge_feats = load_partition_feats(part_config, self.part_id)
            for name in node_feats:
                # The feature name has the following format: node_type + "/" + feature_name
                # to avoid feature name collision for different node types.
                ntype, feat_name = name.split('/')
                data_name = HeteroDataName(True, ntype, feat_name)
                self.init_data(name=str(data_name), policy_str=data_name.policy_str,
                               data_tensor=node_feats[name])
                self.orig_data.add(str(data_name))
            for name in edge_feats:
                # The feature name has the following format: edge_type + "/" + feature_name
                # to avoid feature name collision for different edge types.
                etype, feat_name = name.split('/')
                data_name = HeteroDataName(False, etype, feat_name)
                self.init_data(name=str(data_name), policy_str=data_name.policy_str,
                               data_tensor=edge_feats[name])
                self.orig_data.add(str(data_name))

    def start(self):
        """ Start graph store server.
        """
        # start server
        server_state = ServerState(kv_store=self, local_g=self.client_g,
                                   partition_book=self.gpb, keep_alive=self.keep_alive)
        print('start graph service on server {} for part {}'.format(
            self.server_id, self.part_id))
        start_server(server_id=self.server_id,
                     ip_config=self.ip_config,
                     num_servers=self.num_servers,
                     num_clients=self.num_clients,
                     server_state=server_state,
                     net_type=self.net_type)
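
# A minimal launch sketch for DistGraphServer (illustrative only; the file names, counts and
# the surrounding launch tooling are placeholders, not a tested configuration):
#
#     serv = DistGraphServer(server_id=0, ip_config='ip_config.txt', num_servers=1,
#                            num_clients=1, part_config='data/mygraph.json')
#     serv.start()   # blocks and serves graph structure / kvstore requests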

class DistGraph:
    '''The class for accessing a distributed graph.

    This class provides a subset of DGLGraph APIs for accessing partitioned graph data in
    distributed GNN training and inference. Thus, its main use case is to work with
    distributed sampling APIs to generate mini-batches and perform forward and
    backward computation on the mini-batches.

    The class can run in two modes: the standalone mode and the distributed mode.

    * When a user runs the training script normally, ``DistGraph`` will be in the standalone mode.
      In this mode, the input data must be constructed by
      :py:meth:`~dgl.distributed.partition.partition_graph` with only one partition. This mode is
      used for testing and debugging purpose. In this mode, users have to provide ``part_config``
      so that ``DistGraph`` can load the input graph.
    * When a user runs the training script with the distributed launch script, ``DistGraph`` will
      be set into the distributed mode. This is used for actual distributed training. All data of
      partitions are loaded by the ``DistGraph`` servers, which are created by DGL's launch script.
      ``DistGraph`` connects with the servers to access the partitioned graph data.

    Currently, the ``DistGraph`` servers and clients run on the same set of machines
    in the distributed mode. ``DistGraph`` uses shared-memory to access the partition data
    in the local machine. This gives the best performance for distributed training.

    Users may want to run ``DistGraph`` servers and clients on separate sets of machines.
    In this case, a user may want to disable shared memory by passing
    ``disable_shared_mem=True`` when creating ``DistGraphServer``. When shared memory is disabled,
    a user has to pass a partition book.

    Parameters
    ----------
    graph_name : str
        The name of the graph. This name has to be the same as the one used for
        partitioning a graph in :py:meth:`dgl.distributed.partition.partition_graph`.
    gpb : GraphPartitionBook, optional
        The partition book object. Normally, users do not need to provide the partition book.
        This argument is necessary only when users want to run server process and trainer
        processes on different machines.
    part_config : str, optional
        The path of partition configuration file generated by
        :py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.

    Examples
    --------
    The example shows the creation of ``DistGraph`` in the standalone mode.

    >>> dgl.distributed.partition_graph(g, 'graph_name', 1, num_hops=1, part_method='metis',
    ...                                 out_path='output/', reshuffle=True)
    >>> g = dgl.distributed.DistGraph('graph_name', part_config='output/graph_name.json')

    The example shows the creation of ``DistGraph`` in the distributed mode.

    >>> g = dgl.distributed.DistGraph('graph-name')

    The code below shows the mini-batch training using ``DistGraph``.

    >>> def sample(seeds):
    ...     seeds = th.LongTensor(np.asarray(seeds))
    ...     frontier = dgl.distributed.sample_neighbors(g, seeds, 10)
    ...     return dgl.to_block(frontier, seeds)
    >>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000,
    ...                                             collate_fn=sample, shuffle=True)
    >>> for block in dataloader:
    ...     feat = g.ndata['features'][block.srcdata[dgl.NID]]
    ...     labels = g.ndata['labels'][block.dstdata[dgl.NID]]
    ...     pred = model(block, feat)

    Note
    ----
    DGL's distributed training by default runs server processes and trainer processes on the same
    set of machines. If users need to run them on different sets of machines, it requires
    manually setting up servers and trainers. The setup is not fully tested yet.
    '''
    def __init__(self, graph_name, gpb=None, part_config=None):
        self.graph_name = graph_name
        self._gpb_input = gpb
        if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
            assert part_config is not None, \
                    'When running in the standalone mode, the partition config file is required'
            self._client = get_kvstore()
            assert self._client is not None, \
                    'Distributed module is not initialized. Please call dgl.distributed.initialize.'
            # Load graph partition data.
            g, node_feats, edge_feats, self._gpb, _, _, _ = load_partition(part_config, 0)
            assert self._gpb.num_partitions() == 1, \
                    'The standalone mode can only work with the graph data with one partition'
            if self._gpb is None:
                self._gpb = gpb
            self._g = g
            for name in node_feats:
                # The feature name has the following format: node_type + "/" + feature_name.
                ntype, feat_name = name.split('/')
                self._client.add_data(str(HeteroDataName(True, ntype, feat_name)),
                                      node_feats[name],
                                      NodePartitionPolicy(self._gpb, ntype=ntype))
            for name in edge_feats:
                # The feature name has the following format: edge_type + "/" + feature_name.
                etype, feat_name = name.split('/')
                self._client.add_data(str(HeteroDataName(False, etype, feat_name)),
                                      edge_feats[name],
                                      EdgePartitionPolicy(self._gpb, etype=etype))
            self._client.map_shared_data(self._gpb)
            rpc.set_num_client(1)
        else:
            self._init()
            # Tell the backup servers to load the graph structure from shared memory.
            for server_id in range(self._client.num_servers):
                rpc.send_request(server_id, InitGraphRequest(graph_name))
            for server_id in range(self._client.num_servers):
                rpc.recv_response()
            self._client.barrier()

        self._ndata_store = {}
        self._edata_store = {}
        self._ndata = NodeDataView(self)
        self._edata = EdgeDataView(self)

        self._num_nodes = 0
        self._num_edges = 0
        for part_md in self._gpb.metadata():
            self._num_nodes += int(part_md['num_nodes'])
            self._num_edges += int(part_md['num_edges'])

        # When we store node/edge types in a list, they are stored in the order of type IDs.
        self._ntype_map = {ntype:i for i, ntype in enumerate(self.ntypes)}
        self._etype_map = {etype:i for i, etype in enumerate(self.etypes)}

        # Get canonical edge types.
        # TODO(zhengda) this requires the server to store the graph with coo format.
        eid = []
        for etype in self.etypes:
            type_eid = F.zeros((1,), F.int64, F.cpu())
            eid.append(self._gpb.map_to_homo_eid(type_eid, etype))
        eid = F.cat(eid, 0)
        src, dst = dist_find_edges(self, eid)
        src_tids, _ = self._gpb.map_to_per_ntype(src)
        dst_tids, _ = self._gpb.map_to_per_ntype(dst)
        self._canonical_etypes = []
        etype_ids = F.arange(0, len(self.etypes))
        for src_tid, etype_id, dst_tid in zip(src_tids, etype_ids, dst_tids):
            src_tid = F.as_scalar(src_tid)
            etype_id = F.as_scalar(etype_id)
            dst_tid = F.as_scalar(dst_tid)
            self._canonical_etypes.append((self.ntypes[src_tid], self.etypes[etype_id],
                                           self.ntypes[dst_tid]))
        self._etype2canonical = {}
        for src_type, etype, dst_type in self._canonical_etypes:
            if etype in self._etype2canonical:
                self._etype2canonical[etype] = ()
            else:
                self._etype2canonical[etype] = (src_type, etype, dst_type)

    def _init(self):
        self._client = get_kvstore()
        assert self._client is not None, \
                'Distributed module is not initialized. Please call dgl.distributed.initialize.'
        self._g = _get_graph_from_shared_mem(self.graph_name)
        self._gpb = get_shared_mem_partition_book(self.graph_name, self._g)
        if self._gpb is None:
            self._gpb = self._gpb_input
        self._client.map_shared_data(self._gpb)

    def __getstate__(self):
        return self.graph_name, self._gpb, self._canonical_etypes

    def __setstate__(self, state):
        self.graph_name, self._gpb_input, self._canonical_etypes = state
        self._init()

        self._etype2canonical = {}
        for src_type, etype, dst_type in self._canonical_etypes:
            if etype in self._etype2canonical:
                self._etype2canonical[etype] = ()
            else:
                self._etype2canonical[etype] = (src_type, etype, dst_type)
        self._ndata_store = {}
        self._edata_store = {}
        self._ndata = NodeDataView(self)
        self._edata = EdgeDataView(self)
        self._num_nodes = 0
        self._num_edges = 0
        for part_md in self._gpb.metadata():
            self._num_nodes += int(part_md['num_nodes'])
            self._num_edges += int(part_md['num_edges'])

    @property
    def local_partition(self):
        ''' Return the local partition on the client

        DistGraph provides a global view of the distributed graph. Internally,
        it may contain a partition of the graph if it is co-located with
        the server. When servers and clients run on separate sets of machines,
        this returns None.

        Returns
        -------
        DGLGraph
            The local partition
        '''
        return self._g

    @property
    def nodes(self):
        '''Return a node view
        '''
        return HeteroNodeView(self)

    @property
    def edges(self):
        '''Return an edge view
        '''
        return HeteroEdgeView(self)

    @property
    def ndata(self):
        """Return the data view of all the nodes.

        Returns
        -------
        NodeDataView
            The data view in the distributed graph storage.
        """
        assert len(self.ntypes) == 1, "ndata only works for a graph with one node type."
        return self._ndata

    @property
    def edata(self):
        """Return the data view of all the edges.

        Returns
        -------
        EdgeDataView
            The data view in the distributed graph storage.
        """
        assert len(self.etypes) == 1, "edata only works for a graph with one edge type."
        return self._edata

    @property
    def idtype(self):
        """The dtype of graph index

        Returns
        -------
        backend dtype object
            th.int32/th.int64 or tf.int32/tf.int64 etc.

        See Also
        --------
        long
        int
        """
        # TODO(da?): describe when self._g is None and idtype shouldn't be called.
        return F.int64

    @property
    def device(self):
        """Get the device context of this graph.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
        >>> print(g.device)
        device(type='cpu')
        >>> g = g.to('cuda:0')
        >>> print(g.device)
        device(type='cuda', index=0)

        Returns
        -------
        Device context object
        """
        # TODO(da?): describe when self._g is None and device shouldn't be called.
        return F.cpu()

    def is_pinned(self):
        """Check if the graph structure is pinned to the page-locked memory.

        Returns
        -------
        bool
            True if the graph structure is pinned.
        """
        # (Xin Yao): Currently we don't support pinning a DistGraph.
        return False

    @property
    def ntypes(self):
        """Return the list of node types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> g = DistGraph("test")
        >>> g.ntypes
        ['_U']
        """
        return self._gpb.ntypes

    @property
    def etypes(self):
        """Return the list of edge types of this graph.

        Returns
        -------
        list of str

        Examples
        --------

        >>> g = DistGraph("test")
        >>> g.etypes
        ['_E']
        """
        # Currently, we only support a graph with one edge type.
        return self._gpb.etypes

    @property
    def canonical_etypes(self):
        """Return all the canonical edge types in the graph.

        A canonical edge type is a string triplet ``(str, str, str)``
        for source node type, edge type and destination node type.

        Returns
        -------
        list[(str, str, str)]
            All the canonical edge type triplets in a list.

        Notes
        -----
        DGL internally assigns an integer ID for each edge type. The returned
        edge type names are sorted according to their IDs.

        See Also
        --------
        etypes

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch

        >>> g = DistGraph("test")
        >>> g.canonical_etypes
        [('user', 'follows', 'user'),
         ('user', 'follows', 'game'),
         ('user', 'plays', 'game')]
        """
        return self._canonical_etypes

    def to_canonical_etype(self, etype):
        """Convert an edge type to the corresponding canonical edge type in the graph.

        A canonical edge type is a string triplet ``(str, str, str)``
        for source node type, edge type and destination node type.

        The function expects that the given edge type name can uniquely identify a canonical
        edge type. DGL will raise an error if this is not the case.

        Parameters
        ----------
        etype : str or (str, str, str)
            If :attr:`etype` is an edge type (str), it returns the corresponding canonical edge
            type in the graph. If :attr:`etype` is already a canonical edge type,
            it directly returns the input unchanged.

        Returns
        -------
        (str, str, str)
            The canonical edge type corresponding to the edge type.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch

        >>> g = DistGraph("test")
        >>> g.canonical_etypes
        [('user', 'follows', 'user'),
         ('user', 'follows', 'game'),
         ('user', 'plays', 'game')]

        >>> g.to_canonical_etype('plays')
        ('user', 'plays', 'game')
        >>> g.to_canonical_etype(('user', 'plays', 'game'))
        ('user', 'plays', 'game')

        See Also
        --------
        canonical_etypes
        """
        if etype is None:
            if len(self.etypes) != 1:
                raise DGLError('Edge type name must be specified if there are more than one '
                               'edge types.')
            etype = self.etypes[0]
        if isinstance(etype, tuple):
            return etype
        else:
            ret = self._etype2canonical.get(etype, None)
            if ret is None:
                raise DGLError('Edge type "{}" does not exist.'.format(etype))
            if len(ret) != 3:
                raise DGLError('Edge type "{}" is ambiguous. Please use canonical edge type '
                               'in the form of (srctype, etype, dsttype)'.format(etype))
            return ret

    def get_ntype_id(self, ntype):
        """Return the ID of the given node type.

        ntype can also be None. If so, there should be only one node type in the
        graph.

        Parameters
        ----------
        ntype : str
            Node type

        Returns
        -------
        int
        """
        if ntype is None:
            if len(self._ntype_map) != 1:
                raise DGLError('Node type name must be specified if there are more than one '
                               'node types.')
            return 0
        return self._ntype_map[ntype]

    def get_etype_id(self, etype):
        """Return the ID of the given edge type.

        etype can also be None. If so, there should be only one edge type in the
        graph.

        Parameters
        ----------
        etype : str or tuple of str
            Edge type

        Returns
        -------
        int
        """
        if etype is None:
            if len(self._etype_map) != 1:
                raise DGLError('Edge type name must be specified if there are more than one '
                               'edge types.')
            return 0
        return self._etype_map[etype]

    def number_of_nodes(self, ntype=None):
        """Alias of :func:`num_nodes`"""
        return self.num_nodes(ntype)

    def number_of_edges(self, etype=None):
        """Alias of :func:`num_edges`"""
        return self.num_edges(etype)

    def num_nodes(self, ntype=None):
        """Return the total number of nodes in the distributed graph.

        Parameters
        ----------
        ntype : str, optional
            The node type name. If given, it returns the number of nodes of the
            type. If not given (default), it returns the total number of nodes of all types.

        Returns
        -------
        int
            The number of nodes

        Examples
        --------
        >>> g = dgl.distributed.DistGraph('ogb-product')
        >>> print(g.num_nodes())
        2449029
        """
        if ntype is None:
            if len(self.ntypes) == 1:
                return self._gpb._num_nodes(self.ntypes[0])
            else:
                return sum([self._gpb._num_nodes(ntype) for ntype in self.ntypes])
        return self._gpb._num_nodes(ntype)

    def num_edges(self, etype=None):
        """Return the total number of edges in the distributed graph.

        Parameters
        ----------
        etype : str or (str, str, str), optional
            The type name of the edges. The allowed type name formats are:

            * ``(str, str, str)`` for source node type, edge type and destination node type.
            * or one ``str`` edge type name if the name can uniquely identify a
              triplet format in the graph.

            If not provided, return the total number of edges regardless of the types
            in the graph.

        Returns
        -------
        int
            The number of edges

        Examples
        --------
        >>> g = dgl.distributed.DistGraph('ogb-product')
        >>> print(g.num_edges())
        123718280
        """
        if etype is None:
            if len(self.etypes) == 1:
                return self._gpb._num_edges(self.etypes[0])
            else:
                return sum([self._gpb._num_edges(etype) for etype in self.etypes])
        return self._gpb._num_edges(etype)

    def out_degrees(self, u=ALL):
        """Return the out-degree(s) of the given nodes.

        It computes the out-degree(s).
        It does not support heterogeneous graphs yet.

        Parameters
        ----------
        u : node IDs
            The node IDs. The allowed formats are:

            * ``int``: A single node.
            * Int Tensor: Each element is a node ID. The tensor must have the same device type
              and ID data type as the graph's.
            * iterable[int]: Each element is a node ID.

            If not given, return the out-degrees of all the nodes.

        Returns
        -------
        int or Tensor
            The out-degree(s) of the node(s) in a Tensor. The i-th element is the out-degree
            of the i-th input node. If :attr:`u` is an ``int``, return an ``int`` too.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch

        Query for all nodes.

        >>> g.out_degrees()
        tensor([2, 2, 0, 0])

        Query for nodes 1 and 2.

        >>> g.out_degrees(torch.tensor([1, 2]))
        tensor([2, 0])

        See Also
        --------
        in_degrees
        """
        if is_all(u):
            u = F.arange(0, self.number_of_nodes())
        return dist_out_degrees(self, u)

    def in_degrees(self, v=ALL):
        """Return the in-degree(s) of the given nodes.

        It computes the in-degree(s).
        It does not support heterogeneous graphs yet.

        Parameters
        ----------
        v : node IDs
            The node IDs. The allowed formats are:

            * ``int``: A single node.
            * Int Tensor: Each element is a node ID. The tensor must have the same device type
              and ID data type as the graph's.
            * iterable[int]: Each element is a node ID.

            If not given, return the in-degrees of all the nodes.

        Returns
        -------
        int or Tensor
            The in-degree(s) of the node(s) in a Tensor. The i-th element is the in-degree
            of the i-th input node. If :attr:`v` is an ``int``, return an ``int`` too.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch

        Query for all nodes.

        >>> g.in_degrees()
        tensor([0, 2, 1, 1])

        Query for nodes 1 and 2.

        >>> g.in_degrees(torch.tensor([1, 2]))
        tensor([2, 1])

        See Also
        --------
        out_degrees
        """
        if is_all(v):
            v = F.arange(0, self.number_of_nodes())
        return dist_in_degrees(self, v)

    def node_attr_schemes(self):
        """Return the node feature schemes.

        Each feature scheme is a named tuple that stores the shape and data type
        of the node feature.

        Returns
        -------
        dict of str to schemes
            The schemes of node feature columns.

        Examples
        --------
        The following uses PyTorch backend.

        >>> g.node_attr_schemes()
        {'h': Scheme(shape=(4,), dtype=torch.float32)}

        See Also
        --------
        edge_attr_schemes
        """
        schemes = {}
        for key in self.ndata:
            schemes[key] = infer_scheme(self.ndata[key])
        return schemes

    def edge_attr_schemes(self):
        """Return the edge feature schemes.

        Each feature scheme is a named tuple that stores the shape and data type
        of the edge feature.

        Returns
        -------
        dict of str to schemes
            The schemes of edge feature columns.

        Examples
        --------
        The following uses PyTorch backend.

        >>> g.edge_attr_schemes()
        {'h': Scheme(shape=(4,), dtype=torch.float32)}

        See Also
        --------
        node_attr_schemes
        """
        schemes = {}
        for key in self.edata:
            schemes[key] = infer_scheme(self.edata[key])
        return schemes

    def rank(self):
        ''' The rank of the current DistGraph.

        This returns a unique number to identify the DistGraph object among all of
        the client processes.

        Returns
        -------
        int
            The rank of the current DistGraph.
        '''
        return role.get_global_rank()

    def find_edges(self, edges, etype=None):
        """ Given an edge ID array, return the source
        and destination node ID arrays ``s`` and ``d``. ``s[i]`` and ``d[i]``
        are the source and destination node IDs for edge ``eid[i]``.

        Parameters
        ----------
        edges : Int Tensor
            Each element is an ID. The tensor must have the same device type
            and ID data type as the graph's.

        etype : str or (str, str, str), optional
            The type names of the edges. The allowed type name formats are:

            * ``(str, str, str)`` for source node type, edge type and destination node type.
            * or one ``str`` edge type name if the name can uniquely identify a
              triplet format in the graph.

            Can be omitted if the graph has only one type of edges.

        Returns
        -------
        tensor
            The source node ID array.
        tensor
            The destination node ID array.
        """
        if etype is None:
            assert len(self.etypes) == 1, 'find_edges requires etype for heterogeneous graphs.'

        gpb = self.get_partition_book()
        if len(gpb.etypes) > 1:
            # if etype is a canonical edge type (str, str, str), extract the edge type
            if isinstance(etype, tuple) and len(etype) == 3:
                etype = etype[1]
            edges = gpb.map_to_homo_eid(edges, etype)
        src, dst = dist_find_edges(self, edges)
        if len(gpb.ntypes) > 1:
            _, src = gpb.map_to_per_ntype(src)
            _, dst = gpb.map_to_per_ntype(dst)
        return src, dst

    def edge_subgraph(self, edges, relabel_nodes=True, store_ids=True):
        """Return a subgraph induced on the given edges.

        An edge-induced subgraph is equivalent to creating a new graph using the given
        edges. In addition to extracting the subgraph, DGL also copies the features
        of the extracted nodes and edges to the resulting graph. The copy is *lazy*
        and incurs data movement only when needed.

        If the graph is heterogeneous, DGL extracts a subgraph per relation and composes
        them as the resulting graph. Thus, the resulting graph has the same set of relations
        as the input one.

        Parameters
        ----------
        edges : Int Tensor or dict[(str, str, str), Int Tensor]
            The edges to form the subgraph. Each element is an edge ID. The tensor must have
            the same device type and ID data type as the graph's.

            If the graph is homogeneous, one can directly pass an Int Tensor.
            Otherwise, the argument must be a dictionary with keys being edge types
            and values being the edge IDs in the above formats.
        relabel_nodes : bool, optional
            If True, it will remove the isolated nodes and relabel the incident nodes in the
            extracted subgraph.
        store_ids : bool, optional
            If True, it will store the raw IDs of the extracted edges in the ``edata`` of the
            resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will
            also store the raw IDs of the incident nodes in the ``ndata`` of the resulting
            graph under name ``dgl.NID``.

        Returns
        -------
        G : DGLGraph
            The subgraph.
        """
        if isinstance(edges, dict):
            # TODO(zhengda) we need to directly generate subgraph of all relations with
            # one invocation.
            if len(edges) > 0 and isinstance(next(iter(edges)), tuple):
                subg = {etype: self.find_edges(edges[etype], etype[1]) for etype in edges}
            else:
                subg = {}
                for etype in edges:
                    assert len(self._etype2canonical[etype]) == 3, \
                            'the etype in input edges is ambiguous'
                    subg[self._etype2canonical[etype]] = self.find_edges(edges[etype], etype)
            num_nodes = {ntype: self.number_of_nodes(ntype) for ntype in self.ntypes}
            subg = dgl_heterograph(subg, num_nodes_dict=num_nodes)
            for etype in edges:
                subg.edges[etype].data[EID] = edges[etype]
        else:
            assert len(self.etypes) == 1
            subg = self.find_edges(edges)
            subg = dgl_graph(subg, num_nodes=self.number_of_nodes())
            subg.edata[EID] = edges

        if relabel_nodes:
            subg = compact_graphs(subg)
        assert store_ids, 'edge_subgraph always stores original node/edge IDs.'
        return subg
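
    # Illustrative use of ``edge_subgraph`` (a sketch; ``g`` and ``eids`` are placeholders):
    # for a homogeneous DistGraph ``g`` and an edge ID tensor ``eids``,
    #     sg = g.edge_subgraph(eids)
    # returns a local DGLGraph whose ``sg.edata[dgl.EID]`` holds the original edge IDs.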

    def get_partition_book(self):
        """Get the partition information.

        Returns
        -------
        GraphPartitionBook
            Object that stores all graph partition information.
        """
        return self._gpb

    def get_node_partition_policy(self, ntype):
        """Get the partition policy for a node type.

        When creating a new distributed tensor, we need to provide a partition policy
        that indicates how to distribute data of the distributed tensor in a cluster
        of machines. When we load a distributed graph in the cluster, we have pre-defined
        partition policies for each node type and each edge type. By providing
        the node type, we can refer to the pre-defined partition policy for the node type.

        Parameters
        ----------
        ntype : str
            The node type

        Returns
        -------
        PartitionPolicy
            The partition policy for the node type.
        """
        return NodePartitionPolicy(self.get_partition_book(), ntype)

    def get_edge_partition_policy(self, etype):
        """Get the partition policy for an edge type.

        When creating a new distributed tensor, we need to provide a partition policy
        that indicates how to distribute data of the distributed tensor in a cluster
        of machines. When we load a distributed graph in the cluster, we have pre-defined
        partition policies for each node type and each edge type. By providing
        the edge type, we can refer to the pre-defined partition policy for the edge type.

        Parameters
        ----------
        etype : str
            The edge type

        Returns
        -------
        PartitionPolicy
            The partition policy for the edge type.
        """
        return EdgePartitionPolicy(self.get_partition_book(), etype)

    def barrier(self):
        '''Barrier for all client nodes.

        This API blocks the current process until all the clients invoke this API.
        Please use this API with caution.
        '''
        self._client.barrier()

    def sample_neighbors(self, seed_nodes, fanout, edge_dir='in', prob=None,
                         exclude_edges=None, replace=False,
                         output_device=None):
        # pylint: disable=unused-argument
        """Sample neighbors from a distributed graph."""
        # Currently prob, exclude_edges, output_device, and edge_dir are ignored.
        if len(self.etypes) > 1:
            frontier = graph_services.sample_etype_neighbors(
                self, seed_nodes, ETYPE, fanout, replace=replace)
        else:
            frontier = graph_services.sample_neighbors(
                self, seed_nodes, fanout, replace=replace)
        return frontier
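
    # A minimal sampling sketch (illustrative only; ``g`` is a DistGraph and ``seeds`` is a
    # node ID tensor -- both placeholder names):
    #     frontier = g.sample_neighbors(seeds, 10)   # sample up to 10 in-neighbors per seed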

    def _get_ndata_names(self, ntype=None):
        ''' Get the names of all node data.
        '''
        names = self._client.gdata_name_list()
        ndata_names = []
        for name in names:
            name = parse_hetero_data_name(name)
            right_type = (name.get_type() == ntype) if ntype is not None else True
            if name.is_node() and right_type:
                ndata_names.append(name)
        return ndata_names

    def _get_edata_names(self, etype=None):
        ''' Get the names of all edge data.
        '''
        names = self._client.gdata_name_list()
        edata_names = []
        for name in names:
            name = parse_hetero_data_name(name)
            right_type = (name.get_type() == etype) if etype is not None else True
            if name.is_edge() and right_type:
                edata_names.append(name)
        return edata_names

def _get_overlap(mask_arr, ids):
    """ Select the IDs given a boolean mask array.

    The boolean mask array indicates all of the IDs to be selected. We want to
    find the overlap between the IDs selected by the boolean mask array and
    the ID array.

    Parameters
    ----------
    mask_arr : 1D tensor
        A boolean mask array.
    ids : 1D tensor
        A vector with IDs.

    Returns
    -------
    1D tensor
        The selected IDs.
    """
    if isinstance(mask_arr, DistTensor):
        masks = mask_arr[ids]
        return F.boolean_mask(ids, masks)
    else:
        masks = F.gather_row(F.tensor(mask_arr), ids)
        return F.boolean_mask(ids, masks)
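
# A small worked example of _get_overlap (a sketch with made-up values): given
# mask_arr = [1, 0, 1, 1] (a boolean mask over IDs 0..3) and ids = [0, 1, 3], the gathered
# mask is [1, 0, 1], so the returned overlap is the ID tensor [0, 3].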

def _split_local(partition_book, rank, elements, local_eles):
    ''' Split the input element list with respect to data locality.
    '''
    num_clients = role.get_num_trainers()
    num_client_per_part = num_clients // partition_book.num_partitions()
    if rank is None:
        rank = role.get_trainer_rank()
    assert rank < num_clients, \
            'The input rank ({}) is incorrect. #Trainers: {}'.format(rank, num_clients)
    # all ranks of the clients in the same machine are in a contiguous range.
    client_id_in_part = rank  % num_client_per_part
    local_eles = _get_overlap(elements, local_eles)

    # get a subset for the local client.
    size = len(local_eles) // num_client_per_part
    # if this isn't the last client in the partition.
    if client_id_in_part + 1 < num_client_per_part:
        return local_eles[(size * client_id_in_part):(size * (client_id_in_part + 1))]
    else:
        return local_eles[(size * client_id_in_part):]
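
# A small worked example of _split_local (made-up numbers): with 4 trainers and 2 partitions,
# num_client_per_part is 2; the trainer with rank 3 has client_id_in_part = 1 and therefore
# receives the second half of its partition's local_eles.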

def _even_offset(n, k):
    ''' Split an array of length n into k segments such that the difference of their lengths
        is at most 1. Return the offset of each segment.
    '''
    eles_per_part = n // k
    offset = np.array([0] + [eles_per_part] * k, dtype=int)
    offset[1 : n - eles_per_part * k + 1] += 1
    return np.cumsum(offset)
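
# A small worked example of _even_offset: _even_offset(10, 3) returns the offsets
# [0, 4, 7, 10], i.e. segment lengths 4, 3 and 3, which differ by at most 1.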

def _split_even_to_part(partition_book, elements):
    ''' Split the input element list evenly.
    '''
    # here we divide the element list as evenly as possible. If we use range partitioning,
    # the split results also respect the data locality. Range partitioning is the default
    # strategy.
    # TODO(zhengda) we need another way to divide the list for other partitioning strategy.
    if isinstance(elements, DistTensor):
        nonzero_count = elements.count_nonzero()
    else:
        elements = F.tensor(elements)
        nonzero_count = F.count_nonzero(elements)
    # compute the offset of each split and ensure that the difference of each partition size
    # is at most 1.
    offsets = _even_offset(nonzero_count, partition_book.num_partitions())
    assert offsets[-1] == nonzero_count

    # Get the elements that belong to the partition.
    partid = partition_book.partid
    left, right = offsets[partid], offsets[partid + 1]

    x = y = 0
    num_elements = len(elements)
    block_size = num_elements // partition_book.num_partitions()
    part_eles = None
    # Compute the nonzero indices block by block instead of on the whole tensor to save memory.
    for idx in range(0, num_elements, block_size):
        nonzero_block = F.nonzero_1d(elements[idx:min(idx+block_size, num_elements)])
        x = y
        y += len(nonzero_block)
        if y > left and x < right:
            start = max(x, left) - x
            end = min(y, right) - x
            tmp = nonzero_block[start:end] + idx
            if part_eles is None:
                part_eles = tmp
            else:
                part_eles = F.cat((part_eles, tmp), 0)
        elif x >= right:
            break

    return part_eles
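# Rough illustration of the block-wise scan above (hypothetical 2-partition setup):
# with 10 set bits in the mask, _even_offset(10, 2) yields offsets [0, 5, 10], so
# partition 0 keeps the first 5 nonzero element IDs and partition 1 the remaining 5.
# Scanning the mask block by block avoids materializing F.nonzero_1d over the whole
# (possibly distributed) tensor at once.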

def _split_random_within_part(partition_book, rank, part_eles):
    # If there is more than one client in a partition, we need to randomly select a subset of
    # elements in the partition for each client. We have to make sure that the sets of elements
    # for different clients are disjoint.

    num_clients = role.get_num_trainers()
    num_client_per_part = num_clients // partition_book.num_partitions()
    if num_client_per_part == 1:
        return part_eles
    if rank is None:
        rank = role.get_trainer_rank()
    assert rank < num_clients, \
            'The input rank ({}) is incorrect. #Trainers: {}'.format(rank, num_clients)
    client_id_in_part = rank % num_client_per_part
    offset = _even_offset(len(part_eles), num_client_per_part)

    # We set the random seed for each partition so that every process (client) in a partition
    # permutes the elements of the partition in the same way; each process then takes a
    # disjoint subset of the elements.
    np.random.seed(partition_book.partid)
    rand_idx = np.random.permutation(len(part_eles))
    rand_idx = rand_idx[offset[client_id_in_part] : offset[client_id_in_part + 1]]
    idx, _ = F.sort_1d(F.tensor(rand_idx))
    return F.gather_row(part_eles, idx)
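# Note on the seeding above: because every client in the same partition seeds NumPy
# with the partition ID, they all draw the identical permutation, and the disjoint
# slices taken via _even_offset therefore split `part_eles` without overlap.
# For example (illustrative), with 2 clients per partition and 10 elements, client 0
# takes the first 5 permuted positions and client 1 the last 5.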

def _split_by_trainer_id(partition_book, part_eles, trainer_id,
                         num_client_per_part, client_id_in_part):
    # TODO(zhengda): MXNet cannot deal with empty tensors, which makes the implementation
    # much more difficult. Let's just use numpy for the computation for now. We just
    # perform operations on vectors. It shouldn't be too difficult.
    trainer_id = F.asnumpy(trainer_id)
    part_eles = F.asnumpy(part_eles)
    part_id = trainer_id // num_client_per_part
    trainer_id = trainer_id % num_client_per_part
    local_eles = part_eles[np.nonzero(part_id[part_eles] == partition_book.partid)[0]]
    # These are the IDs of the local elements in the partition. The IDs are global IDs.
    remote_eles = part_eles[np.nonzero(part_id[part_eles] != partition_book.partid)[0]]
    # These are the IDs of the remote elements in the partition. The IDs are global IDs.
    local_eles_idx = np.concatenate(
        [np.nonzero(trainer_id[local_eles] == i)[0] for i in range(num_client_per_part)],
        # trainer_id[local_eles] is the trainer ids of local nodes in the partition and we
        # pick out the indices where the node belongs to each trainer i respectively, and
        # concatenate them.
        axis=0
    )
    # `local_eles_idx` is used to sort `local_eles` according to `trainer_id`. It is a
    # permutation of 0...(len(local_eles)-1)
    local_eles = local_eles[local_eles_idx]

    # evenly split local nodes to trainers
    local_offsets = _even_offset(len(local_eles), num_client_per_part)
    # evenly split remote nodes to trainers
    remote_offsets = _even_offset(len(remote_eles), num_client_per_part)

    client_local_eles = local_eles[
        local_offsets[client_id_in_part]:local_offsets[client_id_in_part + 1]]
    client_remote_eles = remote_eles[
        remote_offsets[client_id_in_part]:remote_offsets[client_id_in_part + 1]]
    client_eles = np.concatenate([client_local_eles, client_remote_eles], axis=0)
    return F.tensor(client_eles)
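# Sketch of _split_by_trainer_id (illustrative): elements whose assigned trainer lives
# in this partition are grouped by trainer ID and split evenly among the local clients;
# elements assigned to remote trainers are appended in even shares as well, so each
# client ends up with roughly len(part_eles) / num_client_per_part elements.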

def node_split(nodes, partition_book=None, ntype='_N', rank=None, force_even=True,
               node_trainer_ids=None):
    ''' Split nodes and return a subset for the local rank.

    This function splits the input nodes based on the partition book and
    returns a subset of nodes for the local rank. This method is used for
    dividing workloads for distributed training.

    The input nodes are stored as a boolean mask vector whose length equals the
    number of nodes in the graph; a true (1) value indicates that the node at the
    corresponding position is selected.

    There are two strategies to split the nodes. By default, the nodes are split
    to maximize data locality: all nodes that belong to a process are returned.
    If ``force_even`` is set to True, the nodes are split evenly so that each
    process gets almost the same number of nodes.

    When ``force_even`` is True, data locality is still preserved if the graph is
    partitioned with Metis and the node/edge IDs are shuffled. In this case, the
    majority of the nodes returned for a process are the ones that belong to the
    process. If node/edge IDs are not shuffled, data locality is not guaranteed.

    Parameters
    ----------
    nodes : 1D tensor or DistTensor
        A boolean mask vector that indicates input nodes.
    partition_book : GraphPartitionBook, optional
        The graph partition book.
    ntype : str, optional
        The node type of the input nodes.
    rank : int, optional
        The rank of a process. If not given, the rank of the current process is used.
    force_even : bool, optional
        Force the nodes to be split evenly.
    node_trainer_ids : 1D tensor or DistTensor, optional
        If not None, split the nodes to the trainers on the same machine according to
        trainer IDs assigned to each node. Otherwise, split randomly.

    Returns
    -------
    1D-tensor
        The vector of node IDs that belong to the rank.
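
    Examples
    --------
    A minimal sketch (assuming an initialized ``dgl.distributed.DistGraph`` ``g``
    whose boolean training mask is stored in ``g.ndata['train_mask']``; the names
    are illustrative):

    >>> train_mask = g.ndata['train_mask']
    >>> pb = g.get_partition_book()
    >>> train_nids = dgl.distributed.node_split(train_mask, pb, force_even=True)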
    '''
    if not isinstance(nodes, DistTensor):
        assert partition_book is not None, 'Regular tensor requires a partition book.'
    elif partition_book is None:
        partition_book = nodes.part_policy.partition_book

    assert len(nodes) == partition_book._num_nodes(ntype), \
            'The length of the boolean mask vector should be the number of nodes in the graph.'
    if rank is None:
        rank = role.get_trainer_rank()
    if force_even:
        num_clients = role.get_num_trainers()
        num_client_per_part = num_clients // partition_book.num_partitions()
        assert num_clients % partition_book.num_partitions() == 0, \
                'The total number of clients should be a multiple of the number of partitions.'
        part_nid = _split_even_to_part(partition_book, nodes)
        if num_client_per_part == 1:
            return part_nid
        elif node_trainer_ids is None:
            return _split_random_within_part(partition_book, rank, part_nid)
        else:
            trainer_id = node_trainer_ids[0:len(node_trainer_ids)]
            max_trainer_id = F.as_scalar(F.reduce_max(trainer_id)) + 1

            if max_trainer_id > num_clients:
                # Allow the trainer_id-based split to be reused when the number of trainers
                # is smaller than the `num_trainers_per_machine` that was assigned during
                # partitioning.
                assert max_trainer_id % num_clients == 0
                trainer_id //= (max_trainer_id // num_clients)

            client_id_in_part = rank % num_client_per_part
            return _split_by_trainer_id(partition_book, part_nid, trainer_id,
                                        num_client_per_part, client_id_in_part)
    else:
        # Get all nodes that belong to the rank.
        local_nids = partition_book.partid2nids(partition_book.partid, ntype=ntype)
        return _split_local(partition_book, rank, nodes, local_nids)

def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=True,
               edge_trainer_ids=None):
    ''' Split edges and return a subset for the local rank.

    This function splits the input edges based on the partition book and
    returns a subset of edges for the local rank. This method is used for
    dividing workloads for distributed training.

    The input edges can be stored as a boolean mask vector whose length equals the
    number of edges in the graph; a true (1) value indicates that the edge at the
    corresponding position is selected.

    There are two strategies to split the edges. By default, the edges are split
    to maximize data locality: all edges that belong to a process are returned.
    If ``force_even`` is set to True, the edges are split evenly so that each
    process gets almost the same number of edges.

    When ``force_even`` is True, data locality is still preserved if the graph is
    partitioned with Metis and the node/edge IDs are shuffled. In this case, the
    majority of the edges returned for a process are the ones that belong to the
    process. If node/edge IDs are not shuffled, data locality is not guaranteed.

    Parameters
    ----------
    edges : 1D tensor or DistTensor
        A boolean mask vector that indicates input edges.
    partition_book : GraphPartitionBook, optional
        The graph partition book.
    etype : str, optional
        The edge type of the input edges.
    rank : int, optional
        The rank of a process. If not given, the rank of the current process is used.
    force_even : bool, optional
        Force the edges to be split evenly.
    edge_trainer_ids : 1D tensor or DistTensor, optional
        If not None, split the edges to the trainers on the same machine according to
        trainer IDs assigned to each edge. Otherwise, split randomly.

    Returns
    -------
    1D-tensor
        The vector of edge IDs that belong to the rank.
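
    Examples
    --------
    A minimal sketch (assuming an initialized ``dgl.distributed.DistGraph`` ``g`` with
    a boolean edge mask stored in ``g.edata['train_mask']``; the names are illustrative):

    >>> edge_mask = g.edata['train_mask']
    >>> pb = g.get_partition_book()
    >>> train_eids = dgl.distributed.edge_split(edge_mask, pb, force_even=True)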
    '''
    if not isinstance(edges, DistTensor):
        assert partition_book is not None, 'Regular tensor requires a partition book.'
    elif partition_book is None:
        partition_book = edges.part_policy.partition_book
    assert len(edges) == partition_book._num_edges(etype), \
            'The length of the boolean mask vector should be the number of edges in the graph.'
    if rank is None:
        rank = role.get_trainer_rank()
    if force_even:
        num_clients = role.get_num_trainers()
        num_client_per_part = num_clients // partition_book.num_partitions()
        assert num_clients % partition_book.num_partitions() == 0, \
                'The total number of clients should be a multiple of the number of partitions.'
        part_eid = _split_even_to_part(partition_book, edges)
        if num_client_per_part == 1:
            return part_eid
        elif edge_trainer_ids is None:
            return _split_random_within_part(partition_book, rank, part_eid)
        else:
            trainer_id = edge_trainer_ids[0:len(edge_trainer_ids)]
            max_trainer_id = F.as_scalar(F.reduce_max(trainer_id)) + 1

            if max_trainer_id > num_clients:
                # Allow the trainer_id-based split to be reused when the number of trainers
                # is smaller than the `num_trainers_per_machine` that was assigned during
                # partitioning.
                assert max_trainer_id % num_clients == 0
                trainer_id //= (max_trainer_id // num_clients)

            client_id_in_part = rank % num_client_per_part
            return _split_by_trainer_id(partition_book, part_eid, trainer_id,
                                        num_client_per_part, client_id_in_part)
    else:
        # Get all edges that belong to the rank.
        local_eids = partition_book.partid2eids(partition_book.partid, etype=etype)
        return _split_local(partition_book, rank, edges, local_eids)

rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)