Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
68377251
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6921393ae27f7ab9f3f25f9e772ec42cfdf82f63"
Unverified
Commit
68377251
authored
Jan 30, 2024
by
Rhett Ying
Committed by
GitHub
Jan 30, 2024
Browse files
[DistGB] enable DistGraphServer to load graphbolt partitions (#7042)
parent
fe78093f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
25 deletions
+40
-25
python/dgl/distributed/dist_graph.py
python/dgl/distributed/dist_graph.py
+40
-25
No files found.
python/dgl/distributed/dist_graph.py
View file @
68377251
...
@@ -8,7 +8,7 @@ from collections.abc import MutableMapping
...
@@ -8,7 +8,7 @@ from collections.abc import MutableMapping
import
numpy
as
np
import
numpy
as
np
from
..
import
backend
as
F
,
heterograph_index
from
..
import
backend
as
F
,
graphbolt
as
gb
,
heterograph_index
from
.._ffi.ndarray
import
empty_shared_mem
from
.._ffi.ndarray
import
empty_shared_mem
from
..base
import
ALL
,
DGLError
,
EID
,
ETYPE
,
is_all
,
NID
from
..base
import
ALL
,
DGLError
,
EID
,
ETYPE
,
is_all
,
NID
from
..convert
import
graph
as
dgl_graph
,
heterograph
as
dgl_heterograph
from
..convert
import
graph
as
dgl_graph
,
heterograph
as
dgl_heterograph
...
@@ -88,7 +88,9 @@ class InitGraphResponse(rpc.Response):
...
@@ -88,7 +88,9 @@ class InitGraphResponse(rpc.Response):
self
.
_graph_name
=
state
self
.
_graph_name
=
state
def
_copy_graph_to_shared_mem
(
g
,
graph_name
,
graph_format
):
def
_copy_graph_to_shared_mem
(
g
,
graph_name
,
graph_format
,
use_graphbolt
):
if
use_graphbolt
:
return
g
.
copy_to_shared_memory
(
graph_name
)
new_g
=
g
.
shared_memory
(
graph_name
,
formats
=
graph_format
)
new_g
=
g
.
shared_memory
(
graph_name
,
formats
=
graph_format
)
# We should share the node/edge data to the client explicitly instead of putting them
# We should share the node/edge data to the client explicitly instead of putting them
# in the KVStore because some of the node/edge data may be duplicated.
# in the KVStore because some of the node/edge data may be duplicated.
...
@@ -298,6 +300,30 @@ class EdgeDataView(MutableMapping):
...
@@ -298,6 +300,30 @@ class EdgeDataView(MutableMapping):
return
repr
(
reprs
)
return
repr
(
reprs
)
def _format_partition(graph, graph_format):
    """Cast reserved fields and materialize the requested graph formats.

    Parameters
    ----------
    graph : gb.FusedCSCSamplingGraph or graph with ``ndata``/``edata``
        The loaded partition. A ``gb.FusedCSCSamplingGraph`` is handed back
        untouched. Any other graph is presumably a ``DGLGraph`` (TODO
        confirm against the caller).
    graph_format : str or list of str
        The formats forwarded to ``graph.formats``.

    Returns
    -------
    The input graph itself when it is a ``gb.FusedCSCSamplingGraph``,
    otherwise the graph returned by ``graph.formats(graph_format)`` after
    ``create_formats_()`` has been invoked on it.
    """
    # GraphBolt partitions are returned as-is.
    if isinstance(graph, gb.FusedCSCSamplingGraph):
        return graph

    # Cast every reserved field present in the node/edge data to its
    # declared dtype.
    # TODO(Rui) Forcing the dtypes here is not a perfect solution.
    #   We'd better store all dtypes when mapping to shared memory
    #   and map back with the original dtypes.
    for field, target_dtype in RESERVED_FIELD_DTYPE.items():
        # Node data first, then edge data, for each reserved field.
        for frame_attr in ("ndata", "edata"):
            frame = getattr(graph, frame_attr)
            if field in frame:
                frame[field] = F.astype(frame[field], target_dtype)

    # Create the graph formats specified by the user.
    print(
        "Start to create specified graph formats which may take "
        "non-trivial time."
    )
    formatted = graph.formats(graph_format)
    formatted.create_formats_()
    print(f"Finished creating specified graph formats: {graph_format}")
    return formatted
class
DistGraphServer
(
KVServer
):
class
DistGraphServer
(
KVServer
):
"""The DistGraph server.
"""The DistGraph server.
...
@@ -330,6 +356,8 @@ class DistGraphServer(KVServer):
...
@@ -330,6 +356,8 @@ class DistGraphServer(KVServer):
Disable shared memory.
Disable shared memory.
graph_format : str or list of str
graph_format : str or list of str
The graph formats.
The graph formats.
use_graphbolt : bool
Whether to load GraphBolt partition. Default: False.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -341,6 +369,7 @@ class DistGraphServer(KVServer):
...
@@ -341,6 +369,7 @@ class DistGraphServer(KVServer):
part_config
,
part_config
,
disable_shared_mem
=
False
,
disable_shared_mem
=
False
,
graph_format
=
(
"csc"
,
"coo"
),
graph_format
=
(
"csc"
,
"coo"
),
use_graphbolt
=
False
,
):
):
super
(
DistGraphServer
,
self
).
__init__
(
super
(
DistGraphServer
,
self
).
__init__
(
server_id
=
server_id
,
server_id
=
server_id
,
...
@@ -350,6 +379,7 @@ class DistGraphServer(KVServer):
...
@@ -350,6 +379,7 @@ class DistGraphServer(KVServer):
)
)
self
.
ip_config
=
ip_config
self
.
ip_config
=
ip_config
self
.
num_servers
=
num_servers
self
.
num_servers
=
num_servers
self
.
use_graphbolt
=
use_graphbolt
# Load graph partition data.
# Load graph partition data.
if
self
.
is_backup_server
():
if
self
.
is_backup_server
():
# The backup server doesn't load the graph partition. It'll be initialized afterwards.
# The backup server doesn't load the graph partition. It'll be initialized afterwards.
...
@@ -367,32 +397,17 @@ class DistGraphServer(KVServer):
...
@@ -367,32 +397,17 @@ class DistGraphServer(KVServer):
graph_name
,
graph_name
,
ntypes
,
ntypes
,
etypes
,
etypes
,
)
=
load_partition
(
part_config
,
self
.
part_id
,
load_feats
=
False
)
)
=
load_partition
(
print
(
"load "
+
graph_name
)
part_config
,
# formatting dtype
self
.
part_id
,
# TODO(Rui) Formatting forcely is not a perfect solution.
load_feats
=
False
,
# We'd better store all dtypes when mapping to shared memory
use_graphbolt
=
use_graphbolt
,
# and map back with original dtypes.
for
k
,
dtype
in
RESERVED_FIELD_DTYPE
.
items
():
if
k
in
self
.
client_g
.
ndata
:
self
.
client_g
.
ndata
[
k
]
=
F
.
astype
(
self
.
client_g
.
ndata
[
k
],
dtype
)
if
k
in
self
.
client_g
.
edata
:
self
.
client_g
.
edata
[
k
]
=
F
.
astype
(
self
.
client_g
.
edata
[
k
],
dtype
)
# Create the graph formats specified the users.
print
(
"Start to create specified graph formats which may take "
"non-trivial time."
)
)
self
.
client_g
=
self
.
client_g
.
formats
(
graph_format
)
print
(
"load "
+
graph_name
)
self
.
client_g
.
create_formats_
()
self
.
client_g
=
_format_partition
(
self
.
client_g
,
graph_format
)
print
(
"Finished creating specified graph formats."
)
if
not
disable_shared_mem
:
if
not
disable_shared_mem
:
self
.
client_g
=
_copy_graph_to_shared_mem
(
self
.
client_g
=
_copy_graph_to_shared_mem
(
self
.
client_g
,
graph_name
,
graph_format
self
.
client_g
,
graph_name
,
graph_format
,
use_graphbolt
)
)
if
not
disable_shared_mem
:
if
not
disable_shared_mem
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment