"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "eeae0338e7ad2b3749eac0c8701ec250a1884844"
Unverified Commit fc7cd275 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[RPC] accelerate sampling RPC (#1669)



* support null ndarray

* speed up

* placeholder

* fix

* lint

* lint

* fix

* done

* comment

* Update sampling.py

* Update sampling.py

* Update sampling.py
Co-authored-by: default avatarDa Zheng <zhengda1936@gmail.com>
parent 071cba1f
......@@ -399,6 +399,11 @@ def serialize_to_payload(serializable):
data = bytearray(pickle.dumps((nonarray_pos, nonarray_state)))
return data, array_state
class PlaceHolder:
"""PlaceHolder object for deserialization"""
_PLACEHOLDER = PlaceHolder()
def deserialize_from_payload(cls, data, tensors):
"""Deserialize and reconstruct the object from payload.
......@@ -419,14 +424,15 @@ def deserialize_from_payload(cls, data, tensors):
De-serialized object of class cls.
"""
pos, nonarray_state = pickle.loads(data)
state = [None] * (len(nonarray_state) + len(tensors))
# Use _PLACEHOLDER to distinguish with other deserizliaed elements
state = [_PLACEHOLDER] * (len(nonarray_state) + len(tensors))
for i, no_state in zip(pos, nonarray_state):
state[i] = no_state
if len(tensors) != 0:
j = 0
state_len = len(state)
for i in range(state_len):
if state[i] is None:
if state[i] is _PLACEHOLDER:
state[i] = tensors[j]
j += 1
if len(state) == 1:
......
......@@ -4,6 +4,7 @@ from ..sampling import sample_neighbors as local_sample_neighbors
from . import register_service
from ..convert import graph
from ..base import NID, EID
from ..utils import toindex
from .. import backend as F
__all__ = ['sample_neighbors']
......@@ -46,7 +47,7 @@ class SamplingRequest(Request):
local_g = server_state.graph
partition_book = server_state.partition_book
local_ids = F.astype(partition_book.nid2localnid(
F.tensor(self.seed_nodes), partition_book.partid), local_g.idtype)
self.seed_nodes, partition_book.partid), local_g.idtype)
# local_ids = self.seed_nodes
sampled_graph = local_sample_neighbors(
local_g, local_ids, self.fan_out, self.edge_dir, self.prob, self.replace)
......@@ -55,7 +56,6 @@ class SamplingRequest(Request):
global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]
global_eids = F.gather_row(
local_g.edata[EID], sampled_graph.edata[EID])
res = SamplingResponse(global_src, global_dst, global_eids)
return res
......@@ -83,17 +83,13 @@ def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replac
assert edge_dir == 'in'
req_list = []
partition_book = dist_graph.get_partition_book()
partition_id = F.asnumpy(
partition_book.nid2partid(F.tensor(nodes))).tolist()
node_id_per_partition = [[]
for _ in range(partition_book.num_partitions())]
for pid, node in zip(partition_id, nodes):
node_id_per_partition[pid].append(node)
for pid, node_id in enumerate(node_id_per_partition):
np_nodes = toindex(nodes).tousertensor()
partition_id = partition_book.nid2partid(np_nodes)
for pid in range(partition_book.num_partitions()):
node_id = F.boolean_mask(np_nodes, partition_id == pid)
if len(node_id) != 0:
req = SamplingRequest(
node_id, fanout, edge_dir=edge_dir, prob=prob, replace=replace)
req = SamplingRequest(node_id, fanout, edge_dir=edge_dir,
prob=prob, replace=replace)
req_list.append((pid, req))
res_list = remote_call_to_machine(req_list)
sampled_graph = merge_graphs(res_list, dist_graph.number_of_nodes())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment