Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3ebdee77
Unverified
Commit
3ebdee77
authored
Feb 09, 2024
by
Rhett Ying
Committed by
GitHub
Feb 09, 2024
Browse files
[DistGB] sample with graphbolt on homograph via DistNodeDataLoader (#7108)
parent
7f7967b3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
73 additions
and
13 deletions
+73
-13
python/dgl/dataloading/neighbor_sampler.py
python/dgl/dataloading/neighbor_sampler.py
+4
-2
python/dgl/distributed/dist_graph.py
python/dgl/distributed/dist_graph.py
+6
-1
tests/distributed/test_mp_dataloader.py
tests/distributed/test_mp_dataloader.py
+63
-10
No files found.
python/dgl/dataloading/neighbor_sampler.py
View file @
3ebdee77
...
@@ -192,9 +192,11 @@ class NeighborSampler(BlockSampler):
...
@@ -192,9 +192,11 @@ class NeighborSampler(BlockSampler):
output_device
=
self
.
output_device
,
output_device
=
self
.
output_device
,
exclude_edges
=
exclude_eids
,
exclude_edges
=
exclude_eids
,
)
)
eid
=
frontier
.
edata
[
EID
]
block
=
to_block
(
frontier
,
seed_nodes
)
block
=
to_block
(
frontier
,
seed_nodes
)
block
.
edata
[
EID
]
=
eid
# If sampled from graphbolt-backed DistGraph, `EID` may not be in
# the block.
if
EID
in
frontier
.
edata
.
keys
():
block
.
edata
[
EID
]
=
frontier
.
edata
[
EID
]
seed_nodes
=
block
.
srcdata
[
NID
]
seed_nodes
=
block
.
srcdata
[
NID
]
blocks
.
insert
(
0
,
block
)
blocks
.
insert
(
0
,
block
)
...
...
python/dgl/distributed/dist_graph.py
View file @
3ebdee77
...
@@ -1406,7 +1406,12 @@ class DistGraph:
...
@@ -1406,7 +1406,12 @@ class DistGraph:
)
)
else
:
else
:
frontier
=
graph_services
.
sample_neighbors
(
frontier
=
graph_services
.
sample_neighbors
(
self
,
seed_nodes
,
fanout
,
replace
=
replace
,
prob
=
prob
self
,
seed_nodes
,
fanout
,
replace
=
replace
,
prob
=
prob
,
use_graphbolt
=
self
.
_use_graphbolt
,
)
)
return
frontier
return
frontier
...
...
tests/distributed/test_mp_dataloader.py
View file @
3ebdee77
...
@@ -342,7 +342,7 @@ def check_neg_dataloader(g, num_server, num_workers):
...
@@ -342,7 +342,7 @@ def check_neg_dataloader(g, num_server, num_workers):
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
4
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
...
@@ -429,6 +429,8 @@ def start_node_dataloader(
...
@@ -429,6 +429,8 @@ def start_node_dataloader(
orig_nid
,
orig_nid
,
orig_eid
,
orig_eid
,
groundtruth_g
,
groundtruth_g
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
):
dgl
.
distributed
.
initialize
(
ip_config
)
dgl
.
distributed
.
initialize
(
ip_config
)
gpb
=
None
gpb
=
None
...
@@ -437,7 +439,12 @@ def start_node_dataloader(
...
@@ -437,7 +439,12 @@ def start_node_dataloader(
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
part_config
,
rank
)
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
part_config
,
rank
)
num_nodes_to_sample
=
202
num_nodes_to_sample
=
202
batch_size
=
32
batch_size
=
32
dist_graph
=
DistGraph
(
"test_mp"
,
gpb
=
gpb
,
part_config
=
part_config
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
part_config
=
part_config
,
use_graphbolt
=
use_graphbolt
,
)
assert
len
(
dist_graph
.
ntypes
)
==
len
(
groundtruth_g
.
ntypes
)
assert
len
(
dist_graph
.
ntypes
)
==
len
(
groundtruth_g
.
ntypes
)
assert
len
(
dist_graph
.
etypes
)
==
len
(
groundtruth_g
.
etypes
)
assert
len
(
dist_graph
.
etypes
)
==
len
(
groundtruth_g
.
etypes
)
if
len
(
dist_graph
.
etypes
)
==
1
:
if
len
(
dist_graph
.
etypes
)
==
1
:
...
@@ -459,6 +466,9 @@ def start_node_dataloader(
...
@@ -459,6 +466,9 @@ def start_node_dataloader(
]
]
)
# test int for hetero
)
# test int for hetero
# Enable santity check in distributed sampling.
os
.
environ
[
"DGL_DIST_DEBUG"
]
=
"1"
# We need to test creating DistDataLoader multiple times.
# We need to test creating DistDataLoader multiple times.
for
i
in
range
(
2
):
for
i
in
range
(
2
):
# Create DataLoader for constructing blocks
# Create DataLoader for constructing blocks
...
@@ -472,7 +482,7 @@ def start_node_dataloader(
...
@@ -472,7 +482,7 @@ def start_node_dataloader(
num_workers
=
num_workers
,
num_workers
=
num_workers
,
)
)
for
epoch
in
range
(
2
):
for
_
in
range
(
2
):
for
idx
,
(
_
,
_
,
blocks
)
in
zip
(
for
idx
,
(
_
,
_
,
blocks
)
in
zip
(
range
(
0
,
num_nodes_to_sample
,
batch_size
),
dataloader
range
(
0
,
num_nodes_to_sample
,
batch_size
),
dataloader
):
):
...
@@ -487,6 +497,16 @@ def start_node_dataloader(
...
@@ -487,6 +497,16 @@ def start_node_dataloader(
src_nodes_id
,
dst_nodes_id
,
etype
=
etype
src_nodes_id
,
dst_nodes_id
,
etype
=
etype
)
)
assert
np
.
all
(
F
.
asnumpy
(
has_edges
))
assert
np
.
all
(
F
.
asnumpy
(
has_edges
))
if
use_graphbolt
and
not
return_eids
:
continue
eids
=
orig_eid
[
etype
][
block
.
edata
[
dgl
.
EID
]]
expected_eids
=
groundtruth_g
.
edge_ids
(
src_nodes_id
,
dst_nodes_id
)
assert
th
.
equal
(
eids
,
expected_eids
),
f
"
{
eids
}
!=
{
expected_eids
}
"
del
dataloader
del
dataloader
# this is needed since there's two test here in one process
# this is needed since there's two test here in one process
dgl
.
distributed
.
exit_client
()
dgl
.
distributed
.
exit_client
()
...
@@ -509,7 +529,7 @@ def start_edge_dataloader(
...
@@ -509,7 +529,7 @@ def start_edge_dataloader(
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
part_config
,
rank
)
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
part_config
,
rank
)
num_edges_to_sample
=
202
num_edges_to_sample
=
202
batch_size
=
32
batch_size
=
32
dist_graph
=
DistGraph
(
"test_
mp
"
,
gpb
=
gpb
,
part_config
=
part_config
)
dist_graph
=
DistGraph
(
"test_
sampling
"
,
gpb
=
gpb
,
part_config
=
part_config
)
assert
len
(
dist_graph
.
ntypes
)
==
len
(
groundtruth_g
.
ntypes
)
assert
len
(
dist_graph
.
ntypes
)
==
len
(
groundtruth_g
.
ntypes
)
assert
len
(
dist_graph
.
etypes
)
==
len
(
groundtruth_g
.
etypes
)
assert
len
(
dist_graph
.
etypes
)
==
len
(
groundtruth_g
.
etypes
)
if
len
(
dist_graph
.
etypes
)
==
1
:
if
len
(
dist_graph
.
etypes
)
==
1
:
...
@@ -561,7 +581,14 @@ def start_edge_dataloader(
...
@@ -561,7 +581,14 @@ def start_edge_dataloader(
dgl
.
distributed
.
exit_client
()
dgl
.
distributed
.
exit_client
()
def
check_dataloader
(
g
,
num_server
,
num_workers
,
dataloader_type
):
def
check_dataloader
(
g
,
num_server
,
num_workers
,
dataloader_type
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
ip_config
=
"ip_config.txt"
ip_config
=
"ip_config.txt"
generate_ip_config
(
ip_config
,
num_server
,
num_server
)
generate_ip_config
(
ip_config
,
num_server
,
num_server
)
...
@@ -576,6 +603,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
...
@@ -576,6 +603,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
num_hops
=
num_hops
,
num_hops
=
num_hops
,
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
)
part_config
=
os
.
path
.
join
(
test_dir
,
"test_sampling.json"
)
part_config
=
os
.
path
.
join
(
test_dir
,
"test_sampling.json"
)
if
not
isinstance
(
orig_nid
,
dict
):
if
not
isinstance
(
orig_nid
,
dict
):
...
@@ -594,6 +623,7 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
...
@@ -594,6 +623,7 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
part_config
,
part_config
,
num_server
>
1
,
num_server
>
1
,
num_workers
+
1
,
num_workers
+
1
,
use_graphbolt
,
),
),
)
)
p
.
start
()
p
.
start
()
...
@@ -615,6 +645,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
...
@@ -615,6 +645,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
orig_nid
,
orig_nid
,
orig_eid
,
orig_eid
,
g
,
g
,
use_graphbolt
,
return_eids
,
),
),
)
)
p
.
start
()
p
.
start
()
...
@@ -663,14 +695,35 @@ def create_random_hetero():
...
@@ -663,14 +695,35 @@ def create_random_hetero():
return
g
return
g
@
unittest
.
skip
(
reason
=
"Skip due to glitch in CI"
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
3
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
4
])
@
pytest
.
mark
.
parametrize
(
"dataloader_type"
,
[
"node"
,
"edge"
])
@
pytest
.
mark
.
parametrize
(
"dataloader_type"
,
[
"node"
,
"edge"
])
def
test_dataloader
(
num_server
,
num_workers
,
dataloader_type
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_dataloader_homograph
(
num_server
,
num_workers
,
dataloader_type
,
use_graphbolt
,
return_eids
):
if
dataloader_type
==
"edge"
and
use_graphbolt
:
# GraphBolt does not support edge dataloader.
return
reset_envs
()
reset_envs
()
g
=
CitationGraphDataset
(
"cora"
)[
0
]
g
=
CitationGraphDataset
(
"cora"
)[
0
]
check_dataloader
(
g
,
num_server
,
num_workers
,
dataloader_type
)
check_dataloader
(
g
,
num_server
,
num_workers
,
dataloader_type
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
@
unittest
.
skip
(
reason
=
"Skip due to glitch in CI"
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"dataloader_type"
,
[
"node"
,
"edge"
])
def
test_dataloader_heterograph
(
num_server
,
num_workers
,
dataloader_type
):
reset_envs
()
g
=
create_random_hetero
()
g
=
create_random_hetero
()
check_dataloader
(
g
,
num_server
,
num_workers
,
dataloader_type
)
check_dataloader
(
g
,
num_server
,
num_workers
,
dataloader_type
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment