Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
763bd39f
"...text-generation-inference.git" did not exist on "85aa7e2e7b02608eea04206b6cc0fa0ccced80ef"
Unverified
Commit
763bd39f
authored
Feb 08, 2024
by
Rhett Ying
Committed by
GitHub
Feb 08, 2024
Browse files
[DistGB] sample with graphbolt on homograph via DistDataLoader (#7098)
parent
870d8d02
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
73 additions
and
45 deletions
+73
-45
tests/distributed/test_mp_dataloader.py
tests/distributed/test_mp_dataloader.py
+73
-45
No files found.
tests/distributed/test_mp_dataloader.py
View file @
763bd39f
...
@@ -22,10 +22,19 @@ from utils import generate_ip_config, reset_envs
...
@@ -22,10 +22,19 @@ from utils import generate_ip_config, reset_envs
class
NeighborSampler
(
object
):
class
NeighborSampler
(
object
):
def
__init__
(
self
,
g
,
fanouts
,
sample_neighbors
):
def
__init__
(
self
,
g
,
fanouts
,
sample_neighbors
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
self
.
g
=
g
self
.
g
=
g
self
.
fanouts
=
fanouts
self
.
fanouts
=
fanouts
self
.
sample_neighbors
=
sample_neighbors
self
.
sample_neighbors
=
sample_neighbors
self
.
use_graphbolt
=
use_graphbolt
self
.
return_eids
=
return_eids
def
sample_blocks
(
self
,
seeds
):
def
sample_blocks
(
self
,
seeds
):
import
torch
as
th
import
torch
as
th
...
@@ -35,13 +44,16 @@ class NeighborSampler(object):
...
@@ -35,13 +44,16 @@ class NeighborSampler(object):
for
fanout
in
self
.
fanouts
:
for
fanout
in
self
.
fanouts
:
# For each seed node, sample ``fanout`` neighbors.
# For each seed node, sample ``fanout`` neighbors.
frontier
=
self
.
sample_neighbors
(
frontier
=
self
.
sample_neighbors
(
self
.
g
,
seeds
,
fanout
,
replace
=
True
self
.
g
,
seeds
,
fanout
,
use_graphbolt
=
self
.
use_graphbolt
)
)
# Then we compact the frontier into a bipartite graph for
# Then we compact the frontier into a bipartite graph for
# message passing.
# message passing.
block
=
dgl
.
to_block
(
frontier
,
seeds
)
block
=
dgl
.
to_block
(
frontier
,
seeds
)
# Obtain the seed nodes for next layer.
# Obtain the seed nodes for next layer.
seeds
=
block
.
srcdata
[
dgl
.
NID
]
seeds
=
block
.
srcdata
[
dgl
.
NID
]
if
frontier
.
num_edges
()
>
0
:
if
not
self
.
use_graphbolt
or
self
.
return_eids
:
block
.
edata
[
dgl
.
EID
]
=
frontier
.
edata
[
dgl
.
EID
]
blocks
.
insert
(
0
,
block
)
blocks
.
insert
(
0
,
block
)
return
blocks
return
blocks
...
@@ -53,6 +65,7 @@ def start_server(
...
@@ -53,6 +65,7 @@ def start_server(
part_config
,
part_config
,
disable_shared_mem
,
disable_shared_mem
,
num_clients
,
num_clients
,
use_graphbolt
=
False
,
):
):
print
(
"server: #clients="
+
str
(
num_clients
))
print
(
"server: #clients="
+
str
(
num_clients
))
g
=
DistGraphServer
(
g
=
DistGraphServer
(
...
@@ -63,6 +76,7 @@ def start_server(
...
@@ -63,6 +76,7 @@ def start_server(
part_config
,
part_config
,
disable_shared_mem
=
disable_shared_mem
,
disable_shared_mem
=
disable_shared_mem
,
graph_format
=
[
"csc"
,
"coo"
],
graph_format
=
[
"csc"
,
"coo"
],
use_graphbolt
=
use_graphbolt
,
)
)
g
.
start
()
g
.
start
()
...
@@ -75,30 +89,36 @@ def start_dist_dataloader(
...
@@ -75,30 +89,36 @@ def start_dist_dataloader(
drop_last
,
drop_last
,
orig_nid
,
orig_nid
,
orig_eid
,
orig_eid
,
group_id
=
0
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
):
import
dgl
import
torch
as
th
os
.
environ
[
"DGL_GROUP_ID"
]
=
str
(
group_id
)
dgl
.
distributed
.
initialize
(
ip_config
)
dgl
.
distributed
.
initialize
(
ip_config
)
gpb
=
None
gpb
=
None
disable_shared_mem
=
num_server
>
0
disable_shared_mem
=
num_server
>
1
if
disable_shared_mem
:
if
disable_shared_mem
:
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
part_config
,
rank
)
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
part_config
,
rank
)
num_nodes_to_sample
=
202
num_nodes_to_sample
=
202
batch_size
=
32
batch_size
=
32
train_nid
=
th
.
arange
(
num_nodes_to_sample
)
train_nid
=
th
.
arange
(
num_nodes_to_sample
)
dist_graph
=
DistGraph
(
"test_mp"
,
gpb
=
gpb
,
part_config
=
part_config
)
dist_graph
=
DistGraph
(
"test_sampling"
,
for
i
in
range
(
num_server
):
gpb
=
gpb
,
part
,
_
,
_
,
_
,
_
,
_
,
_
=
load_partition
(
part_config
,
i
)
part_config
=
part_config
,
use_graphbolt
=
use_graphbolt
,
)
# Create sampler
# Create sampler
sampler
=
NeighborSampler
(
sampler
=
NeighborSampler
(
dist_graph
,
[
5
,
10
],
dgl
.
distributed
.
sample_neighbors
dist_graph
,
[
5
,
10
],
dgl
.
distributed
.
sample_neighbors
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
)
# Enable santity check in distributed sampling.
os
.
environ
[
"DGL_DIST_DEBUG"
]
=
"1"
# We need to test creating DistDataLoader multiple times.
# We need to test creating DistDataLoader multiple times.
for
i
in
range
(
2
):
for
i
in
range
(
2
):
# Create DataLoader for constructing blocks
# Create DataLoader for constructing blocks
...
@@ -113,7 +133,7 @@ def start_dist_dataloader(
...
@@ -113,7 +133,7 @@ def start_dist_dataloader(
groundtruth_g
=
CitationGraphDataset
(
"cora"
)[
0
]
groundtruth_g
=
CitationGraphDataset
(
"cora"
)[
0
]
max_nid
=
[]
max_nid
=
[]
for
epoch
in
range
(
2
):
for
_
in
range
(
2
):
for
idx
,
blocks
in
zip
(
for
idx
,
blocks
in
zip
(
range
(
0
,
num_nodes_to_sample
,
batch_size
),
dataloader
range
(
0
,
num_nodes_to_sample
,
batch_size
),
dataloader
):
):
...
@@ -129,6 +149,16 @@ def start_dist_dataloader(
...
@@ -129,6 +149,16 @@ def start_dist_dataloader(
src_nodes_id
,
dst_nodes_id
src_nodes_id
,
dst_nodes_id
)
)
assert
np
.
all
(
F
.
asnumpy
(
has_edges
))
assert
np
.
all
(
F
.
asnumpy
(
has_edges
))
if
use_graphbolt
and
not
return_eids
:
continue
eids
=
orig_eid
[
block
.
edata
[
dgl
.
EID
]]
expected_eids
=
groundtruth_g
.
edge_ids
(
src_nodes_id
,
dst_nodes_id
)
assert
th
.
equal
(
eids
,
expected_eids
),
f
"
{
eids
}
!=
{
expected_eids
}
"
if
drop_last
:
if
drop_last
:
assert
(
assert
(
np
.
max
(
max_nid
)
np
.
max
(
max_nid
)
...
@@ -311,23 +341,22 @@ def check_neg_dataloader(g, num_server, num_workers):
...
@@ -311,23 +341,22 @@ def check_neg_dataloader(g, num_server, num_workers):
assert
p
.
exitcode
==
0
assert
p
.
exitcode
==
0
@
unittest
.
skip
(
reason
=
"Skip due to glitch in CI"
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
3
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
4
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
4
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"num_groups"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
def
test_dist_dataloader
(
num_server
,
num_workers
,
drop_last
,
num_groups
):
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_dist_dataloader
(
num_server
,
num_workers
,
drop_last
,
use_graphbolt
,
return_eids
):
reset_envs
()
reset_envs
()
# No multiple partitions on single machine for
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
# multiple client groups in case of race condition.
os
.
environ
[
"DGL_NUM_SAMPLER"
]
=
str
(
num_workers
)
if
num_groups
>
1
:
num_server
=
1
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
ip_config
=
"ip_config.txt"
ip_config
=
"ip_config.txt"
generate_ip_config
(
ip_config
,
num_server
,
num_server
)
generate_ip_config
(
ip_config
,
num_server
,
num_server
)
g
=
CitationGraphDataset
(
"cora"
)[
0
]
g
=
CitationGraphDataset
(
"cora"
)[
0
]
print
(
g
.
idtype
)
num_parts
=
num_server
num_parts
=
num_server
num_hops
=
1
num_hops
=
1
...
@@ -339,6 +368,8 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
...
@@ -339,6 +368,8 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
num_hops
=
num_hops
,
num_hops
=
num_hops
,
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
)
part_config
=
os
.
path
.
join
(
test_dir
,
"test_sampling.json"
)
part_config
=
os
.
path
.
join
(
test_dir
,
"test_sampling.json"
)
...
@@ -353,36 +384,33 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
...
@@ -353,36 +384,33 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
part_config
,
part_config
,
num_server
>
1
,
num_server
>
1
,
num_workers
+
1
,
num_workers
+
1
,
use_graphbolt
,
),
),
)
)
p
.
start
()
p
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
1
)
pserver_list
.
append
(
p
)
pserver_list
.
append
(
p
)
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_NUM_SAMPLER"
]
=
str
(
num_workers
)
ptrainer_list
=
[]
ptrainer_list
=
[]
num_trainers
=
1
num_trainers
=
1
for
trainer_id
in
range
(
num_trainers
):
for
trainer_id
in
range
(
num_trainers
):
for
group_id
in
range
(
num_groups
):
p
=
ctx
.
Process
(
p
=
ctx
.
Process
(
target
=
start_dist_dataloader
,
target
=
start_dist_dataloader
,
args
=
(
args
=
(
trainer_id
,
trainer_id
,
ip_config
,
ip_config
,
part_config
,
part_config
,
num_server
,
num_server
,
drop_last
,
drop_last
,
orig_nid
,
orig_nid
,
orig_eid
,
orig_eid
,
use_graphbolt
,
group_id
,
return_eids
,
),
),
)
)
p
.
start
()
p
.
start
()
time
.
sleep
(
time
.
sleep
(
1
)
# avoid race condition when instantiating DistGraph
1
ptrainer_list
.
append
(
p
)
)
# avoid race condition when instantiating DistGraph
ptrainer_list
.
append
(
p
)
for
p
in
ptrainer_list
:
for
p
in
ptrainer_list
:
p
.
join
()
p
.
join
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment