"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "cef4f65cf78793d1826b78e1e0c0b6f45dbd6c68"
Unverified Commit fc7cd275 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[RPC] accelerate sampling RPC (#1669)



* support null ndarray

* speed up

* placeholder

* fix

* lint

* lint

* fix

* done

* comment

* Update sampling.py

* Update sampling.py

* Update sampling.py
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
parent 071cba1f
...@@ -399,6 +399,11 @@ def serialize_to_payload(serializable): ...@@ -399,6 +399,11 @@ def serialize_to_payload(serializable):
data = bytearray(pickle.dumps((nonarray_pos, nonarray_state))) data = bytearray(pickle.dumps((nonarray_pos, nonarray_state)))
return data, array_state return data, array_state
class PlaceHolder:
"""PlaceHolder object for deserialization"""
_PLACEHOLDER = PlaceHolder()
def deserialize_from_payload(cls, data, tensors): def deserialize_from_payload(cls, data, tensors):
"""Deserialize and reconstruct the object from payload. """Deserialize and reconstruct the object from payload.
...@@ -419,14 +424,15 @@ def deserialize_from_payload(cls, data, tensors): ...@@ -419,14 +424,15 @@ def deserialize_from_payload(cls, data, tensors):
De-serialized object of class cls. De-serialized object of class cls.
""" """
pos, nonarray_state = pickle.loads(data) pos, nonarray_state = pickle.loads(data)
state = [None] * (len(nonarray_state) + len(tensors)) # Use _PLACEHOLDER to distinguish from other deserialized elements
state = [_PLACEHOLDER] * (len(nonarray_state) + len(tensors))
for i, no_state in zip(pos, nonarray_state): for i, no_state in zip(pos, nonarray_state):
state[i] = no_state state[i] = no_state
if len(tensors) != 0: if len(tensors) != 0:
j = 0 j = 0
state_len = len(state) state_len = len(state)
for i in range(state_len): for i in range(state_len):
if state[i] is None: if state[i] is _PLACEHOLDER:
state[i] = tensors[j] state[i] = tensors[j]
j += 1 j += 1
if len(state) == 1: if len(state) == 1:
......
...@@ -4,6 +4,7 @@ from ..sampling import sample_neighbors as local_sample_neighbors ...@@ -4,6 +4,7 @@ from ..sampling import sample_neighbors as local_sample_neighbors
from . import register_service from . import register_service
from ..convert import graph from ..convert import graph
from ..base import NID, EID from ..base import NID, EID
from ..utils import toindex
from .. import backend as F from .. import backend as F
__all__ = ['sample_neighbors'] __all__ = ['sample_neighbors']
...@@ -46,7 +47,7 @@ class SamplingRequest(Request): ...@@ -46,7 +47,7 @@ class SamplingRequest(Request):
local_g = server_state.graph local_g = server_state.graph
partition_book = server_state.partition_book partition_book = server_state.partition_book
local_ids = F.astype(partition_book.nid2localnid( local_ids = F.astype(partition_book.nid2localnid(
F.tensor(self.seed_nodes), partition_book.partid), local_g.idtype) self.seed_nodes, partition_book.partid), local_g.idtype)
# local_ids = self.seed_nodes # local_ids = self.seed_nodes
sampled_graph = local_sample_neighbors( sampled_graph = local_sample_neighbors(
local_g, local_ids, self.fan_out, self.edge_dir, self.prob, self.replace) local_g, local_ids, self.fan_out, self.edge_dir, self.prob, self.replace)
...@@ -55,7 +56,6 @@ class SamplingRequest(Request): ...@@ -55,7 +56,6 @@ class SamplingRequest(Request):
global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst] global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]
global_eids = F.gather_row( global_eids = F.gather_row(
local_g.edata[EID], sampled_graph.edata[EID]) local_g.edata[EID], sampled_graph.edata[EID])
res = SamplingResponse(global_src, global_dst, global_eids) res = SamplingResponse(global_src, global_dst, global_eids)
return res return res
...@@ -83,17 +83,13 @@ def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replac ...@@ -83,17 +83,13 @@ def sample_neighbors(dist_graph, nodes, fanout, edge_dir='in', prob=None, replac
assert edge_dir == 'in' assert edge_dir == 'in'
req_list = [] req_list = []
partition_book = dist_graph.get_partition_book() partition_book = dist_graph.get_partition_book()
np_nodes = toindex(nodes).tousertensor()
partition_id = F.asnumpy( partition_id = partition_book.nid2partid(np_nodes)
partition_book.nid2partid(F.tensor(nodes))).tolist() for pid in range(partition_book.num_partitions()):
node_id_per_partition = [[] node_id = F.boolean_mask(np_nodes, partition_id == pid)
for _ in range(partition_book.num_partitions())]
for pid, node in zip(partition_id, nodes):
node_id_per_partition[pid].append(node)
for pid, node_id in enumerate(node_id_per_partition):
if len(node_id) != 0: if len(node_id) != 0:
req = SamplingRequest( req = SamplingRequest(node_id, fanout, edge_dir=edge_dir,
node_id, fanout, edge_dir=edge_dir, prob=prob, replace=replace) prob=prob, replace=replace)
req_list.append((pid, req)) req_list.append((pid, req))
res_list = remote_call_to_machine(req_list) res_list = remote_call_to_machine(req_list)
sampled_graph = merge_graphs(res_list, dist_graph.number_of_nodes()) sampled_graph = merge_graphs(res_list, dist_graph.number_of_nodes())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment