graph_services.py 11.1 KB
Newer Older
1
"""A set of graph services of getting subgraphs from DistGraph"""
2
3
4
from collections import namedtuple

from .rpc import Request, Response, send_requests_to_machine, recv_responses
Jinjing Zhou's avatar
Jinjing Zhou committed
5
from ..sampling import sample_neighbors as local_sample_neighbors
6
from ..transform import in_subgraph as local_in_subgraph
Jinjing Zhou's avatar
Jinjing Zhou committed
7
8
9
from . import register_service
from ..convert import graph
from ..base import NID, EID
10
from ..utils import toindex
Jinjing Zhou's avatar
Jinjing Zhou committed
11
12
from .. import backend as F

13
# Public API of this module: the two distributed subgraph-extraction entry points.
__all__ = ['sample_neighbors', 'in_subgraph']
Jinjing Zhou's avatar
Jinjing Zhou committed
14
15

# Service ID under which SamplingRequest / SubgraphResponse are registered
# (see the register_service calls at the end of this module).
SAMPLING_SERVICE_ID = 6657
16
# Service ID under which InSubgraphRequest / SubgraphResponse are registered
# (see the register_service calls at the end of this module).
INSUBGRAPH_SERVICE_ID = 6658
Jinjing Zhou's avatar
Jinjing Zhou committed
17

18
19
class SubgraphResponse(Response):
    """Reply sent back for both sampling and in-subgraph requests.

    The response carries the edges of the extracted subgraph as three
    parallel vectors.

    Parameters
    ----------
    global_src : tensor
        Source node IDs of the edges.
    global_dst : tensor
        Destination node IDs of the edges.
    global_eids : tensor
        IDs of the edges.
    """

    def __init__(self, global_src, global_dst, global_eids):
        self.global_src, self.global_dst, self.global_eids = (
            global_src, global_dst, global_eids)

    def __getstate__(self):
        # The tuple order here must match the unpacking in __setstate__.
        state = (self.global_src, self.global_dst, self.global_eids)
        return state

    def __setstate__(self, state):
        src, dst, eids = state
        self.global_src = src
        self.global_dst = dst
        self.global_eids = eids


33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace):
    """Sample neighbors of the seed nodes from the local partition.

    The input nodes use global Ids. We map the global node Ids to local node Ids,
    perform sampling on the local graph and map the sampled results back to the
    global Id space.

    Parameters
    ----------
    local_g : graph
        The graph of the local partition. Its ``NID`` node data and ``EID`` edge
        data are used to translate local Ids back to global Ids.
    partition_book : partition book
        Used to translate the global seed node Ids to Ids local to this partition.
    seed_nodes : tensor
        Global Ids of the nodes to sample from.
    fan_out, edge_dir, prob, replace
        Forwarded unchanged to the local ``sample_neighbors`` implementation.

    Returns
    -------
    (tensor, tensor, tensor)
        Global source node Ids, global destination node Ids and global edge Ids
        of the sampled edges.
    """
    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
    # The local sampler is given Ids in the id type of the local graph.
    local_ids = F.astype(local_ids, local_g.idtype)
    sampled_graph = local_sample_neighbors(
        local_g, local_ids, fan_out, edge_dir, prob, replace)
    global_nid_mapping = local_g.ndata[NID]
    src, dst = sampled_graph.edges()
    # Look the Ids up through the backend gather (the same primitive used for
    # the edge Ids below) instead of raw tensor indexing, so the mapping does
    # not depend on the indexing rules of a particular framework.
    global_src = F.gather_row(global_nid_mapping, src)
    global_dst = F.gather_row(global_nid_mapping, dst)
    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
    return global_src, global_dst, global_eids


53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def _in_subgraph(local_g, partition_book, seed_nodes):
    """Get the in-subgraph of the seed nodes from the local partition.

    The input nodes use global Ids. We map the global node Ids to local node Ids,
    extract the in-subgraph on the local graph and map the results back to the
    global Id space.

    Parameters
    ----------
    local_g : graph
        The graph of the local partition. Its ``NID`` node data and ``EID`` edge
        data are used to translate local Ids back to global Ids.
    partition_book : partition book
        Used to translate the global seed node Ids to Ids local to this partition.
    seed_nodes : tensor
        Global Ids of the nodes whose in-edges are extracted.

    Returns
    -------
    (tensor, tensor, tensor)
        Global source node Ids, global destination node Ids and global edge Ids
        of the extracted edges.
    """
    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
    # The local routine is given Ids in the id type of the local graph.
    local_ids = F.astype(local_ids, local_g.idtype)
    sampled_graph = local_in_subgraph(local_g, local_ids)
    global_nid_mapping = local_g.ndata[NID]
    src, dst = sampled_graph.edges()
    # Look the Ids up through the backend gather (the same primitive used for
    # the edge Ids below) instead of raw tensor indexing, so the mapping does
    # not depend on the indexing rules of a particular framework.
    global_src = F.gather_row(global_nid_mapping, src)
    global_dst = F.gather_row(global_nid_mapping, dst)
    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
    return global_src, global_dst, global_eids


Jinjing Zhou's avatar
Jinjing Zhou committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class SamplingRequest(Request):
    """Request asking a server to sample neighbors of a set of seed nodes.

    Parameters
    ----------
    nodes : tensor
        The seed nodes, stored as ``seed_nodes``.
    fan_out : int
        The number of neighbors to sample per node.
    edge_dir : str, optional
        Edge direction to sample along; defaults to ``'in'``.
    prob : str, optional
        Passed to the sampler as the probability argument; defaults to None.
    replace : bool, optional
        Passed to the sampler as the replacement flag; defaults to False.
    """

    def __init__(self, nodes, fan_out, edge_dir='in', prob=None, replace=False):
        self.seed_nodes = nodes
        self.fan_out = fan_out
        self.edge_dir = edge_dir
        self.prob = prob
        self.replace = replace

    def __getstate__(self):
        # The field order here must match the unpacking in __setstate__.
        state = (self.seed_nodes, self.edge_dir, self.prob, self.replace, self.fan_out)
        return state

    def __setstate__(self, state):
        seed_nodes, edge_dir, prob, replace, fan_out = state
        self.seed_nodes = seed_nodes
        self.edge_dir = edge_dir
        self.prob = prob
        self.replace = replace
        self.fan_out = fan_out

    def process_request(self, server_state):
        """Run the sampling on the server's graph and wrap the result in a response."""
        src, dst, eids = _sample_neighbors(
            server_state.graph, server_state.partition_book, self.seed_nodes,
            self.fan_out, self.edge_dir, self.prob, self.replace)
        return SubgraphResponse(src, dst, eids)


class InSubgraphRequest(Request):
    """Request asking a server for the in-subgraph of a set of seed nodes.

    Parameters
    ----------
    nodes : tensor
        The seed nodes, stored as ``seed_nodes``.
    """

    def __init__(self, nodes):
        self.seed_nodes = nodes

    def __getstate__(self):
        # The seed nodes are the only state of this request.
        return self.seed_nodes

    def __setstate__(self, state):
        self.seed_nodes = state

    def process_request(self, server_state):
        """Extract the in-subgraph on the server's graph and wrap it in a response."""
        src, dst, eids = _in_subgraph(
            server_state.graph, server_state.partition_book, self.seed_nodes)
        return SubgraphResponse(src, dst, eids)
Jinjing Zhou's avatar
Jinjing Zhou committed
116
117
118
119


def merge_graphs(res_list, num_nodes):
    """Merge the results returned by multiple servers into one graph.

    Parameters
    ----------
    res_list : list
        Results to merge. Each element exposes ``global_src``, ``global_dst``
        and ``global_eids``. The list must not be empty.
    num_nodes : int
        Number of nodes of the resulting graph.

    Returns
    -------
    graph
        A COO graph built from all edges in ``res_list``, with the edge Ids
        stored as the ``EID`` edge feature.
    """
    if len(res_list) > 1:
        # Several results: concatenate each of the three vectors.
        src_tensor = F.cat([res.global_src for res in res_list], 0)
        dst_tensor = F.cat([res.global_dst for res in res_list], 0)
        eid_tensor = F.cat([res.global_eids for res in res_list], 0)
    else:
        # A single result is used as is.
        only = res_list[0]
        src_tensor = only.global_src
        dst_tensor = only.global_dst
        eid_tensor = only.global_eids
    merged = graph((src_tensor, dst_tensor),
                   restrict_format='coo', num_nodes=num_nodes)
    merged.edata[EID] = eid_tensor
    return merged

140
# Container for a result computed on the local partition. It exposes the same
# attribute names that merge_graphs reads from each result
# (global_src, global_dst, global_eids).
LocalSampledGraph = namedtuple('LocalSampledGraph', 'global_src global_dst global_eids')
Jinjing Zhou's avatar
Jinjing Zhou committed
141

142
143
def _distributed_access(g, nodes, issue_remote_req, local_access):
    """Fetch the local neighborhood of nodes from the distributed graph.

    Some nodes have their neighborhood stored on the local machine, the others
    on remote machines. The remote requests are issued first, then the local
    data is read, and finally the remote responses are collected and everything
    is merged. This ordering hides the latency of the remote accesses behind
    the local work.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    nodes : tensor
        The nodes whose neighborhood are to be fetched.
    issue_remote_req : callable
        Called with the node Ids owned by a remote partition; returns the
        request to send to that partition.
    local_access : callable
        Called with the local partition, the partition book and the local node
        Ids; returns the source, destination and edge Id vectors.

    Returns
    -------
    DGLHeteroGraph
        The subgraph that contains the neighborhoods of all input nodes.
    """
    partition_book = g.get_partition_book()
    nodes = toindex(nodes).tousertensor()
    partition_id = partition_book.nid2partid(nodes)

    remote_reqs = []
    local_nids = None
    for pid in range(partition_book.num_partitions()):
        node_id = F.boolean_mask(nodes, partition_id == pid)
        # We optimize the access on a local partition if the server and the
        # client run on the same machine. With a good partitioning, most of the
        # seed nodes should reside in the local partition. If the server and
        # the client are not co-located, the client has no local partition.
        is_local = pid == partition_book.partid and g.local_partition is not None
        if is_local:
            assert local_nids is None
            local_nids = node_id
            continue
        if len(node_id) != 0:
            remote_reqs.append((pid, issue_remote_req(node_id)))

    # Send the requests to the remote machines before doing any local work.
    msgseq2pos = send_requests_to_machine(remote_reqs) if remote_reqs else None

    # Serve the nodes of the local partition while the remote requests are in flight.
    res_list = []
    if local_nids is not None:
        src, dst, eids = local_access(g.local_partition, partition_book, local_nids)
        res_list.append(LocalSampledGraph(src, dst, eids))

    # Collect the responses from the remote machines.
    if msgseq2pos is not None:
        res_list.extend(recv_responses(msgseq2pos))

    return merge_graphs(res_list, g.number_of_nodes())

204
205
206
207
208
209
210
211
212
213
def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
    """Sample from the neighbors of the given nodes of a distributed graph.

    With replacement, the sampled subgraph may contain parallel edges. Without
    replacement, all neighbors of a node are sampled if ``fanout`` exceeds the
    number of its neighbors.

    Node/edge features are not preserved. The original IDs of the sampled edges
    are stored as the `dgl.EID` feature in the returned graph.

    For now, only input graphs with one node type and one edge type are supported.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    nodes : tensor or dict
        Node ids to sample neighbors from. A dict must contain exactly one
        key-value pair; it is accepted to keep this API consistent with
        dgl.sampling.sample_neighbors.
    fanout : int
        The number of sampled neighbors for each node.
    edge_dir : str, optional
        Edge direction ('in' or 'out'). If 'in', sample from in edges.
        Otherwise, sample from out edges.
    prob : str, optional
        Feature name used as the probabilities associated with each neighbor of
        a node. Its shape should be compatible with a scalar edge feature tensor.
    replace : bool, optional
        If True, sample with replacement.

    Returns
    -------
    DGLHeteroGraph
        A sampled subgraph containing only the sampled neighbor edges from
        ``nodes``. The sampled subgraph has the same metagraph as the original
        one.
    """
    if isinstance(nodes, dict):
        assert len(nodes) == 1, 'The distributed sampler only supports one node type for now.'
        # Exactly one entry: take its value as the seed nodes.
        (nodes,) = nodes.values()

    def make_request(node_ids):
        # Request sent to each remote partition that owns some of the seeds.
        return SamplingRequest(node_ids, fanout, edge_dir=edge_dir,
                               prob=prob, replace=replace)

    def sample_locally(local_g, partition_book, local_nids):
        # Same sampling, executed directly on the local partition.
        return _sample_neighbors(local_g, partition_book, local_nids,
                                 fanout, edge_dir, prob, replace)

    return _distributed_access(g, nodes, make_request, sample_locally)

def in_subgraph(g, nodes):
    """Extract the subgraph containing only the in edges of the given nodes.

    The subgraph keeps the same type schema and the cardinality of the original
    one. Node/edge features are not preserved. The original IDs of the extracted
    edges are stored as the `dgl.EID` feature in the returned graph.

    For now, only input graphs with one node type and one edge type are supported.

    Parameters
    ----------
    g : DistGraph
        The distributed graph structure.
    nodes : tensor or dict
        Node ids whose in edges are extracted. A dict must contain exactly one
        key-value pair.

    Returns
    -------
    DGLHeteroGraph
        The subgraph.
    """
    if isinstance(nodes, dict):
        assert len(nodes) == 1, 'The distributed in_subgraph only supports one node type for now.'
        # Exactly one entry: take its value as the seed nodes.
        (nodes,) = nodes.values()

    def make_request(node_ids):
        # Request sent to each remote partition that owns some of the seeds.
        return InSubgraphRequest(node_ids)

    def extract_locally(local_g, partition_book, local_nids):
        # Same extraction, executed directly on the local partition.
        return _in_subgraph(local_g, partition_book, local_nids)

    return _distributed_access(g, nodes, make_request, extract_locally)

# Register each request type under its service ID; both services reply with
# a SubgraphResponse.
register_service(SAMPLING_SERVICE_ID, SamplingRequest, SubgraphResponse)
register_service(INSUBGRAPH_SERVICE_ID, InSubgraphRequest, SubgraphResponse)