Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8e6cbd62
Unverified
Commit
8e6cbd62
authored
Feb 09, 2024
by
Rhett Ying
Committed by
GitHub
Feb 09, 2024
Browse files
[DistGB] sample with graphbolt on heterograph via DistNodeDataLoader (#7112)
parent
924c5669
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
67 additions
and
33 deletions
+67
-33
python/dgl/distributed/dist_graph.py
python/dgl/distributed/dist_graph.py
+16
-17
tests/distributed/test_mp_dataloader.py
tests/distributed/test_mp_dataloader.py
+51
-16
No files found.
python/dgl/distributed/dist_graph.py
View file @
8e6cbd62
...
@@ -622,18 +622,7 @@ class DistGraph:
...
@@ -622,18 +622,7 @@ class DistGraph:
self
.
_init_ndata_store
()
self
.
_init_ndata_store
()
self
.
_init_edata_store
()
self
.
_init_edata_store
()
self
.
_init_metadata
()
self
.
_num_nodes
=
0
self
.
_num_edges
=
0
for
part_md
in
self
.
_gpb
.
metadata
():
self
.
_num_nodes
+=
int
(
part_md
[
"num_nodes"
])
self
.
_num_edges
+=
int
(
part_md
[
"num_edges"
])
# When we store node/edge types in a list, they are stored in the order of type IDs.
self
.
_ntype_map
=
{
ntype
:
i
for
i
,
ntype
in
enumerate
(
self
.
ntypes
)}
self
.
_etype_map
=
{
etype
:
i
for
i
,
etype
in
enumerate
(
self
.
canonical_etypes
)
}
def
_init
(
self
,
gpb
):
def
_init
(
self
,
gpb
):
self
.
_client
=
get_kvstore
()
self
.
_client
=
get_kvstore
()
...
@@ -698,6 +687,19 @@ class DistGraph:
...
@@ -698,6 +687,19 @@ class DistGraph:
else
:
else
:
self
.
_edata_store
[
etype
]
=
data
self
.
_edata_store
[
etype
]
=
data
def
_init_metadata
(
self
):
self
.
_num_nodes
=
0
self
.
_num_edges
=
0
for
part_md
in
self
.
_gpb
.
metadata
():
self
.
_num_nodes
+=
int
(
part_md
[
"num_nodes"
])
self
.
_num_edges
+=
int
(
part_md
[
"num_edges"
])
# When we store node/edge types in a list, they are stored in the order of type IDs.
self
.
_ntype_map
=
{
ntype
:
i
for
i
,
ntype
in
enumerate
(
self
.
ntypes
)}
self
.
_etype_map
=
{
etype
:
i
for
i
,
etype
in
enumerate
(
self
.
canonical_etypes
)
}
def
__getstate__
(
self
):
def
__getstate__
(
self
):
return
self
.
graph_name
,
self
.
_gpb
,
self
.
_use_graphbolt
return
self
.
graph_name
,
self
.
_gpb
,
self
.
_use_graphbolt
...
@@ -707,11 +709,7 @@ class DistGraph:
...
@@ -707,11 +709,7 @@ class DistGraph:
self
.
_init_ndata_store
()
self
.
_init_ndata_store
()
self
.
_init_edata_store
()
self
.
_init_edata_store
()
self
.
_num_nodes
=
0
self
.
_init_metadata
()
self
.
_num_edges
=
0
for
part_md
in
self
.
_gpb
.
metadata
():
self
.
_num_nodes
+=
int
(
part_md
[
"num_nodes"
])
self
.
_num_edges
+=
int
(
part_md
[
"num_edges"
])
@
property
@
property
def
local_partition
(
self
):
def
local_partition
(
self
):
...
@@ -1403,6 +1401,7 @@ class DistGraph:
...
@@ -1403,6 +1401,7 @@ class DistGraph:
replace
=
replace
,
replace
=
replace
,
etype_sorted
=
etype_sorted
,
etype_sorted
=
etype_sorted
,
prob
=
prob
,
prob
=
prob
,
use_graphbolt
=
self
.
_use_graphbolt
,
)
)
else
:
else
:
frontier
=
graph_services
.
sample_neighbors
(
frontier
=
graph_services
.
sample_neighbors
(
...
...
tests/distributed/test_mp_dataloader.py
View file @
8e6cbd62
...
@@ -487,22 +487,23 @@ def start_node_dataloader(
...
@@ -487,22 +487,23 @@ def start_node_dataloader(
range
(
0
,
num_nodes_to_sample
,
batch_size
),
dataloader
range
(
0
,
num_nodes_to_sample
,
batch_size
),
dataloader
):
):
block
=
blocks
[
-
1
]
block
=
blocks
[
-
1
]
for
src_type
,
etype
,
dst_type
in
block
.
canonical_etypes
:
for
c_etype
in
block
.
canonical_etypes
:
o_src
,
o_dst
=
block
.
edges
(
etype
=
etype
)
src_type
,
_
,
dst_type
=
c_etype
o_src
,
o_dst
=
block
.
edges
(
etype
=
c_etype
)
src_nodes_id
=
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
][
o_src
]
src_nodes_id
=
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
][
o_src
]
dst_nodes_id
=
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
][
o_dst
]
dst_nodes_id
=
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
][
o_dst
]
src_nodes_id
=
orig_nid
[
src_type
][
src_nodes_id
]
src_nodes_id
=
orig_nid
[
src_type
][
src_nodes_id
]
dst_nodes_id
=
orig_nid
[
dst_type
][
dst_nodes_id
]
dst_nodes_id
=
orig_nid
[
dst_type
][
dst_nodes_id
]
has_edges
=
groundtruth_g
.
has_edges_between
(
has_edges
=
groundtruth_g
.
has_edges_between
(
src_nodes_id
,
dst_nodes_id
,
etype
=
etype
src_nodes_id
,
dst_nodes_id
,
etype
=
c_
etype
)
)
assert
np
.
all
(
F
.
asnumpy
(
has_edges
))
assert
np
.
all
(
F
.
asnumpy
(
has_edges
))
if
use_graphbolt
and
not
return_eids
:
if
use_graphbolt
and
not
return_eids
:
continue
continue
eids
=
orig_eid
[
etype
][
block
.
edata
[
dgl
.
EID
]]
eids
=
orig_eid
[
c_
etype
][
block
.
e
dges
[
c_etype
].
data
[
dgl
.
EID
]]
expected_eids
=
groundtruth_g
.
edge_ids
(
expected_eids
=
groundtruth_g
.
edge_ids
(
src_nodes_id
,
dst_nodes_id
src_nodes_id
,
dst_nodes_id
,
etype
=
c_etype
)
)
assert
th
.
equal
(
assert
th
.
equal
(
eids
,
expected_eids
eids
,
expected_eids
...
@@ -610,7 +611,7 @@ def check_dataloader(
...
@@ -610,7 +611,7 @@ def check_dataloader(
if
not
isinstance
(
orig_nid
,
dict
):
if
not
isinstance
(
orig_nid
,
dict
):
orig_nid
=
{
g
.
ntypes
[
0
]:
orig_nid
}
orig_nid
=
{
g
.
ntypes
[
0
]:
orig_nid
}
if
not
isinstance
(
orig_eid
,
dict
):
if
not
isinstance
(
orig_eid
,
dict
):
orig_eid
=
{
g
.
etypes
[
0
]:
orig_eid
}
orig_eid
=
{
g
.
canonical_
etypes
[
0
]:
orig_eid
}
pserver_list
=
[]
pserver_list
=
[]
ctx
=
mp
.
get_context
(
"spawn"
)
ctx
=
mp
.
get_context
(
"spawn"
)
...
@@ -718,14 +719,27 @@ def test_dataloader_homograph(
...
@@ -718,14 +719,27 @@ def test_dataloader_homograph(
)
)
@
unittest
.
skip
(
reason
=
"Skip due to glitch in CI"
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"dataloader_type"
,
[
"node"
,
"edge"
])
@
pytest
.
mark
.
parametrize
(
"dataloader_type"
,
[
"node"
,
"edge"
])
def
test_dataloader_heterograph
(
num_server
,
num_workers
,
dataloader_type
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_dataloader_heterograph
(
num_server
,
num_workers
,
dataloader_type
,
use_graphbolt
,
return_eids
):
if
dataloader_type
==
"edge"
and
use_graphbolt
:
# GraphBolt does not support edge dataloader.
return
reset_envs
()
reset_envs
()
g
=
create_random_hetero
()
g
=
create_random_hetero
()
check_dataloader
(
g
,
num_server
,
num_workers
,
dataloader_type
)
check_dataloader
(
g
,
num_server
,
num_workers
,
dataloader_type
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
@
unittest
.
skip
(
reason
=
"Skip due to glitch in CI"
)
@
unittest
.
skip
(
reason
=
"Skip due to glitch in CI"
)
...
@@ -740,10 +754,18 @@ def test_neg_dataloader(num_server, num_workers):
...
@@ -740,10 +754,18 @@ def test_neg_dataloader(num_server, num_workers):
def
start_multiple_dataloaders
(
def
start_multiple_dataloaders
(
ip_config
,
part_config
,
graph_name
,
orig_g
,
num_dataloaders
,
dataloader_type
ip_config
,
part_config
,
graph_name
,
orig_g
,
num_dataloaders
,
dataloader_type
,
use_graphbolt
,
):
):
dgl
.
distributed
.
initialize
(
ip_config
)
dgl
.
distributed
.
initialize
(
ip_config
)
dist_g
=
dgl
.
distributed
.
DistGraph
(
graph_name
,
part_config
=
part_config
)
dist_g
=
dgl
.
distributed
.
DistGraph
(
graph_name
,
part_config
=
part_config
,
use_graphbolt
=
use_graphbolt
)
if
dataloader_type
==
"node"
:
if
dataloader_type
==
"node"
:
train_ids
=
th
.
arange
(
orig_g
.
num_nodes
())
train_ids
=
th
.
arange
(
orig_g
.
num_nodes
())
batch_size
=
orig_g
.
num_nodes
()
//
100
batch_size
=
orig_g
.
num_nodes
()
//
100
...
@@ -777,13 +799,17 @@ def start_multiple_dataloaders(
...
@@ -777,13 +799,17 @@ def start_multiple_dataloaders(
dgl
.
distributed
.
exit_client
()
dgl
.
distributed
.
exit_client
()
@
unittest
.
skip
(
reason
=
"Skip due to glitch in CI"
)
@
pytest
.
mark
.
parametrize
(
"num_dataloaders"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"num_dataloaders"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"dataloader_type"
,
[
"node"
,
"edge"
])
@
pytest
.
mark
.
parametrize
(
"dataloader_type"
,
[
"node"
,
"edge"
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_multiple_dist_dataloaders
(
def
test_multiple_dist_dataloaders
(
num_dataloaders
,
num_workers
,
dataloader_type
num_dataloaders
,
num_workers
,
dataloader_type
,
use_graphbolt
,
return_eids
):
):
if
dataloader_type
==
"edge"
and
use_graphbolt
:
# GraphBolt does not support edge dataloader.
return
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_NUM_SAMPLER"
]
=
str
(
num_workers
)
os
.
environ
[
"DGL_NUM_SAMPLER"
]
=
str
(
num_workers
)
...
@@ -794,8 +820,15 @@ def test_multiple_dist_dataloaders(
...
@@ -794,8 +820,15 @@ def test_multiple_dist_dataloaders(
generate_ip_config
(
ip_config
,
num_parts
,
num_servers
)
generate_ip_config
(
ip_config
,
num_parts
,
num_servers
)
orig_g
=
dgl
.
rand_graph
(
1000
,
10000
)
orig_g
=
dgl
.
rand_graph
(
1000
,
10000
)
graph_name
=
"test"
graph_name
=
"test_multiple_dataloaders"
partition_graph
(
orig_g
,
graph_name
,
num_parts
,
test_dir
)
partition_graph
(
orig_g
,
graph_name
,
num_parts
,
test_dir
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
part_config
=
os
.
path
.
join
(
test_dir
,
f
"
{
graph_name
}
.json"
)
part_config
=
os
.
path
.
join
(
test_dir
,
f
"
{
graph_name
}
.json"
)
p_servers
=
[]
p_servers
=
[]
...
@@ -809,6 +842,7 @@ def test_multiple_dist_dataloaders(
...
@@ -809,6 +842,7 @@ def test_multiple_dist_dataloaders(
part_config
,
part_config
,
num_servers
>
1
,
num_servers
>
1
,
num_workers
+
1
,
num_workers
+
1
,
use_graphbolt
,
),
),
)
)
p
.
start
()
p
.
start
()
...
@@ -824,6 +858,7 @@ def test_multiple_dist_dataloaders(
...
@@ -824,6 +858,7 @@ def test_multiple_dist_dataloaders(
orig_g
,
orig_g
,
num_dataloaders
,
num_dataloaders
,
dataloader_type
,
dataloader_type
,
use_graphbolt
,
),
),
)
)
p_client
.
start
()
p_client
.
start
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment