"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f82ebb9a03f004558bd198a65abce0121d74f3f2"
Unverified Commit c65d6fa5 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist] format dtypes when loading graph in server (#4228)

* [Dist] format dtypes when loading graph in server

* add test

* refine

* add comments
parent 1feec870
...@@ -336,6 +336,17 @@ class DistGraphServer(KVServer): ...@@ -336,6 +336,17 @@ class DistGraphServer(KVServer):
self.client_g, _, _, self.gpb, graph_name, \ self.client_g, _, _, self.gpb, graph_name, \
ntypes, etypes = load_partition(part_config, self.part_id, load_feats=False) ntypes, etypes = load_partition(part_config, self.part_id, load_feats=False)
print('load ' + graph_name) print('load ' + graph_name)
# formatting dtype
# TODO(Rui) Formatting forcely is not a perfect solution.
# We'd better store all dtypes when mapping to shared memory
# and map back with original dtypes.
for k, dtype in FIELD_DICT.items():
if k in self.client_g.ndata:
self.client_g.ndata[k] = F.astype(
self.client_g.ndata[k], dtype)
if k in self.client_g.edata:
self.client_g.edata[k] = F.astype(
self.client_g.edata[k], dtype)
# Create the graph formats specified the users. # Create the graph formats specified the users.
self.client_g = self.client_g.formats(graph_format) self.client_g = self.client_g.formats(graph_format)
self.client_g.create_formats_() self.client_g.create_formats_()
......
...@@ -35,6 +35,15 @@ def run_server(graph_name, server_id, server_count, num_clients, shared_mem, kee ...@@ -35,6 +35,15 @@ def run_server(graph_name, server_id, server_count, num_clients, shared_mem, kee
disable_shared_mem=not shared_mem, disable_shared_mem=not shared_mem,
graph_format=['csc', 'coo'], keep_alive=keep_alive) graph_format=['csc', 'coo'], keep_alive=keep_alive)
print('start server', server_id) print('start server', server_id)
# verify dtype of underlying graph
cg = g.client_g
for k, dtype in dgl.distributed.dist_graph.FIELD_DICT.items():
if k in cg.ndata:
assert F.dtype(
cg.ndata[k]) == dtype, "Data type of {} in ndata should be {}.".format(k, dtype)
if k in cg.edata:
assert F.dtype(
cg.edata[k]) == dtype, "Data type of {} in edata should be {}.".format(k, dtype)
g.start() g.start()
def emb_init(shape, dtype): def emb_init(shape, dtype):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment