sampling.py 6.99 KB
Newer Older
Jinjing Zhou's avatar
Jinjing Zhou committed
1
"""Sampling module"""
2
3
4
from collections import namedtuple

from .rpc import Request, Response, send_requests_to_machine, recv_responses
Jinjing Zhou's avatar
Jinjing Zhou committed
5
6
7
8
from ..sampling import sample_neighbors as local_sample_neighbors
from . import register_service
from ..convert import graph
from ..base import NID, EID
9
from ..utils import toindex
Jinjing Zhou's avatar
Jinjing Zhou committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from .. import backend as F

# Names exported by ``from <module> import *``.
__all__ = ['sample_neighbors']

# Service ID passed to ``register_service`` together with ``SamplingRequest``
# and ``SamplingResponse`` (see the ``register_service`` call in this module).
SAMPLING_SERVICE_ID = 6657


class SamplingResponse(Response):
    """Response holding the edges sampled for one ``SamplingRequest``.

    Attributes
    ----------
    global_src
        Source nodes of the sampled edges.
    global_dst
        Destination nodes of the sampled edges.
    global_eids
        IDs of the sampled edges.
    """

    def __init__(self, global_src, global_dst, global_eids):
        self.global_src, self.global_dst, self.global_eids = (
            global_src, global_dst, global_eids)

    def __getstate__(self):
        # The state is a plain ``(src, dst, eids)`` tuple.
        return (self.global_src, self.global_dst, self.global_eids)

    def __setstate__(self, state):
        # Unpack in the same order ``__getstate__`` packs.
        src, dst, eids = state
        self.global_src = src
        self.global_dst = dst
        self.global_eids = eids


32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace):
    """Sample neighbors of ``seed_nodes`` from the local partition.

    The seed nodes use global IDs. They are mapped to partition-local IDs,
    sampling runs on ``local_g``, and the sampled edges are mapped back to the
    global ID space.

    Parameters
    ----------
    local_g
        Graph of the local partition. Its ``ndata[NID]`` and ``edata[EID]``
        are used as the local-to-global node and edge ID mappings.
    partition_book
        Partition book providing ``nid2localnid`` and ``partid``.
    seed_nodes
        Global IDs of the nodes to sample from.
    fan_out, edge_dir, prob, replace
        Forwarded unchanged to the local ``sample_neighbors`` implementation.

    Returns
    -------
    tuple
        ``(global_src, global_dst, global_eids)``: source nodes, destination
        nodes and edge IDs of the sampled edges, all in the global ID space.
    """
    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
    # The local sampler is given IDs in the local graph's own id type.
    local_ids = F.astype(local_ids, local_g.idtype)
    sampled_graph = local_sample_neighbors(
        local_g, local_ids, fan_out, edge_dir, prob, replace)

    # Translate local IDs back to global IDs. All three lookups go through the
    # backend's ``gather_row`` so node and edge mappings are handled the same
    # way regardless of which tensor backend is active.
    global_nid_mapping = local_g.ndata[NID]
    src, dst = sampled_graph.edges()
    global_src = F.gather_row(global_nid_mapping, src)
    global_dst = F.gather_row(global_nid_mapping, dst)
    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
    return global_src, global_dst, global_eids


Jinjing Zhou's avatar
Jinjing Zhou committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class SamplingRequest(Request):
    """Request asking a server to sample neighbors of a set of seed nodes.

    Parameters
    ----------
    nodes
        Seed nodes to sample from; stored as ``seed_nodes``.
    fan_out
        Fan-out forwarded to ``_sample_neighbors``.
    edge_dir : str, optional
        Edge direction forwarded to ``_sample_neighbors``. Defaults to 'in'.
    prob : optional
        Probability argument forwarded to ``_sample_neighbors``.
    replace : bool, optional
        Whether to sample with replacement. Defaults to False.
    """

    def __init__(self, nodes, fan_out, edge_dir='in', prob=None, replace=False):
        self.seed_nodes = nodes
        self.fan_out = fan_out
        self.edge_dir = edge_dir
        self.prob = prob
        self.replace = replace

    def __getstate__(self):
        # Field order here must match the unpacking in ``__setstate__``.
        return (self.seed_nodes, self.edge_dir, self.prob,
                self.replace, self.fan_out)

    def __setstate__(self, state):
        (self.seed_nodes, self.edge_dir, self.prob,
         self.replace, self.fan_out) = state

    def process_request(self, server_state):
        """Sample on the server's graph and wrap the result in a response."""
        sampled = _sample_neighbors(
            server_state.graph, server_state.partition_book,
            self.seed_nodes, self.fan_out, self.edge_dir,
            self.prob, self.replace)
        # ``sampled`` is the (global_src, global_dst, global_eids) triple.
        return SamplingResponse(*sampled)
Jinjing Zhou's avatar
Jinjing Zhou committed
76
77
78
79


def merge_graphs(res_list, num_nodes):
    """Merge the sampling results from multiple servers into one graph.

    Parameters
    ----------
    res_list : list
        Results exposing ``global_src``, ``global_dst`` and ``global_eids``.
        Must contain at least one element; an empty list raises ``IndexError``.
    num_nodes : int
        Number of nodes of the graph to build.

    Returns
    -------
    The graph built from all sampled edges, with the sampled edge IDs stored
    in ``edata[EID]``.
    """
    if len(res_list) > 1:
        # Several partial results: concatenate each field along dimension 0.
        src_tensor = F.cat([res.global_src for res in res_list], 0)
        dst_tensor = F.cat([res.global_dst for res in res_list], 0)
        eid_tensor = F.cat([res.global_eids for res in res_list], 0)
    else:
        # A single result is used as-is, without concatenation.
        only = res_list[0]
        src_tensor = only.global_src
        dst_tensor = only.global_dst
        eid_tensor = only.global_eids

    merged = graph((src_tensor, dst_tensor),
                   restrict_format='coo', num_nodes=num_nodes)
    merged.edata[EID] = eid_tensor
    return merged

100
# Container for the result of sampling on the local partition. It exposes the
# same three fields as ``SamplingResponse`` so ``merge_graphs`` can read local
# and remote results through the same attribute names.
LocalSampledGraph = namedtuple('LocalSampledGraph', 'global_src global_dst global_eids')
Jinjing Zhou's avatar
Jinjing Zhou committed
101
102

def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replace=False):
    """Sample from the neighbors of the given nodes from a distributed graph.

    When sampling with replacement, the sampled subgraph could have parallel edges.

    For sampling without replace, if fanout > the number of neighbors, all the
    neighbors are sampled.

    Node/edge features are not preserved. The original IDs of
    the sampled edges are stored as the `dgl.EID` feature in the returned graph.

    Parameters
    ----------
    dist_graph : DistGraph
        The distributed graph.
    nodes : tensor
        Global IDs of the nodes to sample neighbors from. The value is
        converted with ``toindex(nodes).tousertensor()`` before use.
    fanout : int
        The number of sampled neighbors for each node. Forwarded unchanged to
        the per-partition sampler.
    edge_dir : str, optional
        Edge direction. Only 'in' (sample from in edges) is supported by the
        distributed sampler; any other value raises ``AssertionError``.
    prob : str, optional
        Feature name used as the probabilities associated with each neighbor of a node.
        Forwarded unchanged to the per-partition sampler.
    replace : bool, optional
        If True, sample with replacement.

    Returns
    -------
    DGLHeteroGraph
        A graph with ``dist_graph.number_of_nodes()`` nodes containing only
        the sampled neighbor edges of ``nodes``, with the original edge IDs in
        ``edata[EID]``.

    Raises
    ------
    AssertionError
        If ``edge_dir`` is not 'in'.
    """
    # Raised explicitly (rather than via ``assert``) so the check is not
    # stripped under ``python -O``; the exception type is unchanged.
    if edge_dir != 'in':
        raise AssertionError(
            "distributed sample_neighbors only supports edge_dir='in', "
            "got {!r}".format(edge_dir))

    req_list = []
    partition_book = dist_graph.get_partition_book()

    nodes = toindex(nodes).tousertensor()
    partition_id = partition_book.nid2partid(nodes)
    local_nids = None
    for pid in range(partition_book.num_partitions()):
        node_id = F.boolean_mask(nodes, partition_id == pid)
        # We optimize the sampling on a local partition if the server and the client
        # run on the same machine. With a good partitioning, most of the seed nodes
        # should reside in the local partition. If the server and the client
        # are not co-located, the client doesn't have a local partition.
        if pid == partition_book.partid and dist_graph.local_partition is not None:
            assert local_nids is None
            local_nids = node_id
        elif len(node_id) != 0:
            # Only contact a remote partition if it owns at least one seed node.
            req = SamplingRequest(node_id, fanout, edge_dir=edge_dir,
                                  prob=prob, replace=replace)
            req_list.append((pid, req))

    # Send requests to the remote machines before sampling locally, so the
    # local sampling below runs between the send and the receive.
    msgseq2pos = None
    if len(req_list) > 0:
        msgseq2pos = send_requests_to_machine(req_list)

    # Sample neighbors for the nodes in the local partition.
    res_list = []
    if local_nids is not None:
        src, dst, eids = _sample_neighbors(dist_graph.local_partition, partition_book,
                                           local_nids, fanout, edge_dir, prob, replace)
        res_list.append(LocalSampledGraph(src, dst, eids))

    # Receive responses from remote machines.
    if msgseq2pos is not None:
        results = recv_responses(msgseq2pos)
        res_list.extend(results)

    # NOTE(review): if there is no local partition and no partition owns any
    # seed node (e.g. ``nodes`` is empty), ``res_list`` is empty here and
    # ``merge_graphs`` will fail when indexing its first element.
    sampled_graph = merge_graphs(res_list, dist_graph.number_of_nodes())
    return sampled_graph


# Register the sampling request/response classes under SAMPLING_SERVICE_ID.
# This runs at import time of this module.
register_service(SAMPLING_SERVICE_ID, SamplingRequest, SamplingResponse)