Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b0080d5b
Unverified
Commit
b0080d5b
authored
Feb 22, 2024
by
Rhett Ying
Committed by
GitHub
Feb 22, 2024
Browse files
[DistGB] add graphbolt flag into top level API (#7122)
parent
7a10bcb6
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
88 additions
and
43 deletions
+88
-43
python/dgl/distributed/__init__.py
python/dgl/distributed/__init__.py
+1
-0
python/dgl/distributed/constants.py
python/dgl/distributed/constants.py
+3
-0
python/dgl/distributed/dist_context.py
python/dgl/distributed/dist_context.py
+10
-1
python/dgl/distributed/dist_graph.py
python/dgl/distributed/dist_graph.py
+53
-17
python/dgl/distributed/server_state.py
python/dgl/distributed/server_state.py
+9
-1
tests/distributed/test_dist_graph_store.py
tests/distributed/test_dist_graph_store.py
+4
-4
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+5
-15
tests/distributed/test_mp_dataloader.py
tests/distributed/test_mp_dataloader.py
+3
-5
No files found.
python/dgl/distributed/__init__.py
View file @
b0080d5b
...
@@ -19,3 +19,4 @@ from .rpc import *
...
@@ -19,3 +19,4 @@ from .rpc import *
from
.rpc_client
import
connect_to_server
from
.rpc_client
import
connect_to_server
from
.rpc_server
import
start_server
from
.rpc_server
import
start_server
from
.server_state
import
ServerState
from
.server_state
import
ServerState
from
.constants
import
*
python/dgl/distributed/constants.py
View file @
b0080d5b
...
@@ -7,3 +7,6 @@ SERVER_EXIT = "server_exit"
...
@@ -7,3 +7,6 @@ SERVER_EXIT = "server_exit"
DEFAULT_NTYPE
=
"_N"
DEFAULT_NTYPE
=
"_N"
DEFAULT_ETYPE
=
(
DEFAULT_NTYPE
,
"_E"
,
DEFAULT_NTYPE
)
DEFAULT_ETYPE
=
(
DEFAULT_NTYPE
,
"_E"
,
DEFAULT_NTYPE
)
DATA_LOADING_BACKEND_DGL
=
"DGL"
DATA_LOADING_BACKEND_GRAPHBOLT
=
"GraphBolt"
python/dgl/distributed/dist_context.py
View file @
b0080d5b
...
@@ -14,7 +14,11 @@ from enum import Enum
...
@@ -14,7 +14,11 @@ from enum import Enum
from
..
import
utils
from
..
import
utils
from
..base
import
dgl_warning
,
DGLError
from
..base
import
dgl_warning
,
DGLError
from
.
import
rpc
from
.
import
rpc
from
.constants
import
MAX_QUEUE_SIZE
from
.constants
import
(
DATA_LOADING_BACKEND_DGL
,
DATA_LOADING_BACKEND_GRAPHBOLT
,
MAX_QUEUE_SIZE
,
)
from
.kvstore
import
close_kvstore
,
init_kvstore
from
.kvstore
import
close_kvstore
,
init_kvstore
from
.role
import
init_role
from
.role
import
init_role
from
.rpc_client
import
connect_to_server
from
.rpc_client
import
connect_to_server
...
@@ -210,6 +214,7 @@ def initialize(
...
@@ -210,6 +214,7 @@ def initialize(
max_queue_size
=
MAX_QUEUE_SIZE
,
max_queue_size
=
MAX_QUEUE_SIZE
,
net_type
=
None
,
net_type
=
None
,
num_worker_threads
=
1
,
num_worker_threads
=
1
,
data_loading_backend
=
DATA_LOADING_BACKEND_DGL
,
):
):
"""Initialize DGL's distributed module
"""Initialize DGL's distributed module
...
@@ -231,6 +236,8 @@ def initialize(
...
@@ -231,6 +236,8 @@ def initialize(
[Deprecated] Networking type, can be 'socket' only.
[Deprecated] Networking type, can be 'socket' only.
num_worker_threads: int
num_worker_threads: int
The number of OMP threads in each sampler process.
The number of OMP threads in each sampler process.
data_loading_backend: str, optional
The backend for data loading. Can be 'DGL' or 'GraphBolt'.
Note
Note
----
----
...
@@ -270,6 +277,8 @@ def initialize(
...
@@ -270,6 +277,8 @@ def initialize(
int
(
os
.
environ
.
get
(
"DGL_NUM_CLIENT"
)),
int
(
os
.
environ
.
get
(
"DGL_NUM_CLIENT"
)),
os
.
environ
.
get
(
"DGL_CONF_PATH"
),
os
.
environ
.
get
(
"DGL_CONF_PATH"
),
graph_format
=
formats
,
graph_format
=
formats
,
use_graphbolt
=
data_loading_backend
==
DATA_LOADING_BACKEND_GRAPHBOLT
,
)
)
serv
.
start
()
serv
.
start
()
sys
.
exit
()
sys
.
exit
()
...
...
python/dgl/distributed/dist_graph.py
View file @
b0080d5b
...
@@ -18,6 +18,7 @@ from ..heterograph import DGLGraph
...
@@ -18,6 +18,7 @@ from ..heterograph import DGLGraph
from
..ndarray
import
exist_shared_mem_array
from
..ndarray
import
exist_shared_mem_array
from
..transforms
import
compact_graphs
from
..transforms
import
compact_graphs
from
.
import
graph_services
,
role
,
rpc
from
.
import
graph_services
,
role
,
rpc
from
.constants
import
DATA_LOADING_BACKEND_DGL
,
DATA_LOADING_BACKEND_GRAPHBOLT
from
.dist_tensor
import
DistTensor
from
.dist_tensor
import
DistTensor
from
.graph_partition_book
import
(
from
.graph_partition_book
import
(
_etype_str_to_tuple
,
_etype_str_to_tuple
,
...
@@ -50,6 +51,7 @@ from .shared_mem_utils import (
...
@@ -50,6 +51,7 @@ from .shared_mem_utils import (
)
)
INIT_GRAPH
=
800001
INIT_GRAPH
=
800001
QUERY_DATA_LOADING_BACKEND
=
800002
class
InitGraphRequest
(
rpc
.
Request
):
class
InitGraphRequest
(
rpc
.
Request
):
...
@@ -60,20 +62,19 @@ class InitGraphRequest(rpc.Request):
...
@@ -60,20 +62,19 @@ class InitGraphRequest(rpc.Request):
with shared memory.
with shared memory.
"""
"""
def
__init__
(
self
,
graph_name
,
use_graphbolt
):
def
__init__
(
self
,
graph_name
):
self
.
_graph_name
=
graph_name
self
.
_graph_name
=
graph_name
self
.
_use_graphbolt
=
use_graphbolt
def
__getstate__
(
self
):
def
__getstate__
(
self
):
return
self
.
_graph_name
,
self
.
_use_graphbolt
return
self
.
_graph_name
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
self
.
_graph_name
,
self
.
_use_graphbolt
=
state
self
.
_graph_name
=
state
def
process_request
(
self
,
server_state
):
def
process_request
(
self
,
server_state
):
if
server_state
.
graph
is
None
:
if
server_state
.
graph
is
None
:
server_state
.
graph
=
_get_graph_from_shared_mem
(
server_state
.
graph
=
_get_graph_from_shared_mem
(
self
.
_graph_name
,
se
lf
.
_
use_graphbolt
self
.
_graph_name
,
se
rver_state
.
use_graphbolt
)
)
return
InitGraphResponse
(
self
.
_graph_name
)
return
InitGraphResponse
(
self
.
_graph_name
)
...
@@ -91,6 +92,37 @@ class InitGraphResponse(rpc.Response):
...
@@ -91,6 +92,37 @@ class InitGraphResponse(rpc.Response):
self
.
_graph_name
=
state
self
.
_graph_name
=
state
class
QueryDataLoadingBackendRequest
(
rpc
.
Request
):
"""Query the data loading backend."""
def
__getstate__
(
self
):
return
None
def
__setstate__
(
self
,
state
):
pass
def
process_request
(
self
,
server_state
):
backend
=
(
DATA_LOADING_BACKEND_GRAPHBOLT
if
server_state
.
use_graphbolt
else
DATA_LOADING_BACKEND_DGL
)
return
QueryDataLoadingBackendResponse
(
backend
)
class
QueryDataLoadingBackendResponse
(
rpc
.
Response
):
"""Ack the query data loading backend request"""
def
__init__
(
self
,
backend
):
self
.
_backend
=
backend
def
__getstate__
(
self
):
return
self
.
_backend
def
__setstate__
(
self
,
state
):
self
.
_backend
=
state
def
_copy_graph_to_shared_mem
(
g
,
graph_name
,
graph_format
,
use_graphbolt
):
def
_copy_graph_to_shared_mem
(
g
,
graph_name
,
graph_format
,
use_graphbolt
):
if
use_graphbolt
:
if
use_graphbolt
:
return
g
.
copy_to_shared_memory
(
graph_name
)
return
g
.
copy_to_shared_memory
(
graph_name
)
...
@@ -473,6 +505,7 @@ class DistGraphServer(KVServer):
...
@@ -473,6 +505,7 @@ class DistGraphServer(KVServer):
kv_store
=
self
,
kv_store
=
self
,
local_g
=
self
.
client_g
,
local_g
=
self
.
client_g
,
partition_book
=
self
.
gpb
,
partition_book
=
self
.
gpb
,
use_graphbolt
=
self
.
use_graphbolt
,
)
)
print
(
print
(
"start graph service on server {} for part {}"
.
format
(
"start graph service on server {} for part {}"
.
format
(
...
@@ -529,8 +562,6 @@ class DistGraph:
...
@@ -529,8 +562,6 @@ class DistGraph:
part_config : str, optional
part_config : str, optional
The path of partition configuration file generated by
The path of partition configuration file generated by
:py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.
:py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.
use_graphbolt : bool, optional
Whether to load GraphBolt partition. Default: False.
Examples
Examples
--------
--------
...
@@ -564,15 +595,11 @@ class DistGraph:
...
@@ -564,15 +595,11 @@ class DistGraph:
manually setting up servers and trainers. The setup is not fully tested yet.
manually setting up servers and trainers. The setup is not fully tested yet.
"""
"""
def
__init__
(
def
__init__
(
self
,
graph_name
,
gpb
=
None
,
part_config
=
None
):
self
,
graph_name
,
gpb
=
None
,
part_config
=
None
,
use_graphbolt
=
False
):
self
.
graph_name
=
graph_name
self
.
graph_name
=
graph_name
self
.
_use_graphbolt
=
use_graphbolt
if
os
.
environ
.
get
(
"DGL_DIST_MODE"
,
"standalone"
)
==
"standalone"
:
if
os
.
environ
.
get
(
"DGL_DIST_MODE"
,
"standalone"
)
==
"standalone"
:
assert
(
# "GraphBolt is not supported in standalone mode."
use_graphbolt
is
False
self
.
_use_graphbolt
=
False
),
"GraphBolt is not supported in standalone mode."
assert
(
assert
(
part_config
is
not
None
part_config
is
not
None
),
"When running in the standalone model, the partition config file is required"
),
"When running in the standalone model, the partition config file is required"
...
@@ -610,12 +637,16 @@ class DistGraph:
...
@@ -610,12 +637,16 @@ class DistGraph:
self
.
_client
.
map_shared_data
(
self
.
_gpb
)
self
.
_client
.
map_shared_data
(
self
.
_gpb
)
rpc
.
set_num_client
(
1
)
rpc
.
set_num_client
(
1
)
else
:
else
:
# Query the main server about whether GraphBolt is used.
rpc
.
send_request
(
0
,
QueryDataLoadingBackendRequest
())
self
.
_use_graphbolt
=
(
rpc
.
recv_response
().
_backend
==
DATA_LOADING_BACKEND_GRAPHBOLT
)
self
.
_init
(
gpb
)
self
.
_init
(
gpb
)
# Tell the backup servers to load the graph structure from shared memory.
# Tell the backup servers to load the graph structure from shared memory.
for
server_id
in
range
(
self
.
_client
.
num_servers
):
for
server_id
in
range
(
self
.
_client
.
num_servers
):
rpc
.
send_request
(
rpc
.
send_request
(
server_id
,
InitGraphRequest
(
graph_name
))
server_id
,
InitGraphRequest
(
graph_name
,
use_graphbolt
)
)
for
server_id
in
range
(
self
.
_client
.
num_servers
):
for
server_id
in
range
(
self
.
_client
.
num_servers
):
rpc
.
recv_response
()
rpc
.
recv_response
()
self
.
_client
.
barrier
()
self
.
_client
.
barrier
()
...
@@ -1832,3 +1863,8 @@ def edge_split(
...
@@ -1832,3 +1863,8 @@ def edge_split(
rpc
.
register_service
(
INIT_GRAPH
,
InitGraphRequest
,
InitGraphResponse
)
rpc
.
register_service
(
INIT_GRAPH
,
InitGraphRequest
,
InitGraphResponse
)
rpc
.
register_service
(
QUERY_DATA_LOADING_BACKEND
,
QueryDataLoadingBackendRequest
,
QueryDataLoadingBackendResponse
,
)
python/dgl/distributed/server_state.py
View file @
b0080d5b
...
@@ -38,13 +38,16 @@ class ServerState:
...
@@ -38,13 +38,16 @@ class ServerState:
Total number of edges
Total number of edges
partition_book : GraphPartitionBook
partition_book : GraphPartitionBook
Graph Partition book
Graph Partition book
use_graphbolt : bool
Whether to use graphbolt for dataloading.
"""
"""
def
__init__
(
self
,
kv_store
,
local_g
,
partition_book
):
def
__init__
(
self
,
kv_store
,
local_g
,
partition_book
,
use_graphbolt
=
False
):
self
.
_kv_store
=
kv_store
self
.
_kv_store
=
kv_store
self
.
_graph
=
local_g
self
.
_graph
=
local_g
self
.
partition_book
=
partition_book
self
.
partition_book
=
partition_book
self
.
_roles
=
{}
self
.
_roles
=
{}
self
.
_use_graphbolt
=
use_graphbolt
@
property
@
property
def
roles
(
self
):
def
roles
(
self
):
...
@@ -69,5 +72,10 @@ class ServerState:
...
@@ -69,5 +72,10 @@ class ServerState:
def
graph
(
self
,
graph
):
def
graph
(
self
,
graph
):
self
.
_graph
=
graph
self
.
_graph
=
graph
@
property
def
use_graphbolt
(
self
):
"""Whether to use graphbolt for dataloading."""
return
self
.
_use_graphbolt
_init_api
(
"dgl.distributed.server_state"
)
_init_api
(
"dgl.distributed.server_state"
)
tests/distributed/test_dist_graph_store.py
View file @
b0080d5b
...
@@ -141,7 +141,7 @@ def run_client_empty(
...
@@ -141,7 +141,7 @@ def run_client_empty(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
check_dist_graph_empty
(
g
,
num_clients
,
num_nodes
,
num_edges
)
check_dist_graph_empty
(
g
,
num_clients
,
num_nodes
,
num_edges
)
...
@@ -222,7 +222,7 @@ def run_client(
...
@@ -222,7 +222,7 @@ def run_client(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
check_dist_graph
(
check_dist_graph
(
g
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
use_graphbolt
g
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
use_graphbolt
)
)
...
@@ -322,7 +322,7 @@ def run_client_hierarchy(
...
@@ -322,7 +322,7 @@ def run_client_hierarchy(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
node_mask
=
F
.
tensor
(
node_mask
)
node_mask
=
F
.
tensor
(
node_mask
)
edge_mask
=
F
.
tensor
(
edge_mask
)
edge_mask
=
F
.
tensor
(
edge_mask
)
nodes
=
node_split
(
nodes
=
node_split
(
...
@@ -742,7 +742,7 @@ def run_client_hetero(
...
@@ -742,7 +742,7 @@ def run_client_hetero(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
check_dist_graph_hetero
(
check_dist_graph_hetero
(
g
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
use_graphbolt
g
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
use_graphbolt
)
)
...
...
tests/distributed/test_distributed_sampling.py
View file @
b0080d5b
...
@@ -84,9 +84,7 @@ def start_sample_client_shuffle(
...
@@ -84,9 +84,7 @@ def start_sample_client_shuffle(
tmpdir
/
"test_sampling.json"
,
rank
tmpdir
/
"test_sampling.json"
,
rank
)
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
sampled_graph
=
sample_neighbors
(
sampled_graph
=
sample_neighbors
(
dist_graph
,
[
0
,
10
,
99
,
66
,
1024
,
2008
],
3
,
use_graphbolt
=
use_graphbolt
dist_graph
,
[
0
,
10
,
99
,
66
,
1024
,
2008
],
3
,
use_graphbolt
=
use_graphbolt
)
)
...
@@ -477,9 +475,7 @@ def start_hetero_sample_client(
...
@@ -477,9 +475,7 @@ def start_hetero_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
tmpdir
/
"test_sampling.json"
,
rank
)
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
...
@@ -517,9 +513,7 @@ def start_hetero_etype_sample_client(
...
@@ -517,9 +513,7 @@ def start_hetero_etype_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
tmpdir
/
"test_sampling.json"
,
rank
)
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
...
@@ -876,9 +870,7 @@ def start_bipartite_sample_client(
...
@@ -876,9 +870,7 @@ def start_bipartite_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
tmpdir
/
"test_sampling.json"
,
rank
)
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
if
gpb
is
None
:
if
gpb
is
None
:
...
@@ -911,9 +903,7 @@ def start_bipartite_etype_sample_client(
...
@@ -911,9 +903,7 @@ def start_bipartite_etype_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
tmpdir
/
"test_sampling.json"
,
rank
)
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
...
...
tests/distributed/test_mp_dataloader.py
View file @
b0080d5b
...
@@ -11,6 +11,8 @@ import pytest
...
@@ -11,6 +11,8 @@ import pytest
import
torch
as
th
import
torch
as
th
from
dgl.data
import
CitationGraphDataset
from
dgl.data
import
CitationGraphDataset
from
dgl.distributed
import
(
from
dgl.distributed
import
(
DATA_LOADING_BACKEND_DGL
,
DATA_LOADING_BACKEND_GRAPHBOLT
,
DistDataLoader
,
DistDataLoader
,
DistGraph
,
DistGraph
,
DistGraphServer
,
DistGraphServer
,
...
@@ -104,7 +106,6 @@ def start_dist_dataloader(
...
@@ -104,7 +106,6 @@ def start_dist_dataloader(
"test_sampling"
,
"test_sampling"
,
gpb
=
gpb
,
gpb
=
gpb
,
part_config
=
part_config
,
part_config
=
part_config
,
use_graphbolt
=
use_graphbolt
,
)
)
# Create sampler
# Create sampler
...
@@ -443,7 +444,6 @@ def start_node_dataloader(
...
@@ -443,7 +444,6 @@ def start_node_dataloader(
"test_sampling"
,
"test_sampling"
,
gpb
=
gpb
,
gpb
=
gpb
,
part_config
=
part_config
,
part_config
=
part_config
,
use_graphbolt
=
use_graphbolt
,
)
)
assert
len
(
dist_graph
.
ntypes
)
==
len
(
groundtruth_g
.
ntypes
)
assert
len
(
dist_graph
.
ntypes
)
==
len
(
groundtruth_g
.
ntypes
)
assert
len
(
dist_graph
.
etypes
)
==
len
(
groundtruth_g
.
etypes
)
assert
len
(
dist_graph
.
etypes
)
==
len
(
groundtruth_g
.
etypes
)
...
@@ -763,9 +763,7 @@ def start_multiple_dataloaders(
...
@@ -763,9 +763,7 @@ def start_multiple_dataloaders(
use_graphbolt
,
use_graphbolt
,
):
):
dgl
.
distributed
.
initialize
(
ip_config
)
dgl
.
distributed
.
initialize
(
ip_config
)
dist_g
=
dgl
.
distributed
.
DistGraph
(
dist_g
=
dgl
.
distributed
.
DistGraph
(
graph_name
,
part_config
=
part_config
)
graph_name
,
part_config
=
part_config
,
use_graphbolt
=
use_graphbolt
)
if
dataloader_type
==
"node"
:
if
dataloader_type
==
"node"
:
train_ids
=
th
.
arange
(
orig_g
.
num_nodes
())
train_ids
=
th
.
arange
(
orig_g
.
num_nodes
())
batch_size
=
orig_g
.
num_nodes
()
//
100
batch_size
=
orig_g
.
num_nodes
()
//
100
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment