Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
d1bea9e8
Unverified
Commit
d1bea9e8
authored
Feb 28, 2024
by
Rhett Ying
Committed by
GitHub
Feb 28, 2024
Browse files
[DistGB] format dtype when converting partition to GraphBolt format (#7150)
parent
4091a49c
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
165 additions
and
14 deletions
+165
-14
python/dgl/distributed/dist_graph.py
python/dgl/distributed/dist_graph.py
+6
-1
python/dgl/distributed/graph_partition_book.py
python/dgl/distributed/graph_partition_book.py
+10
-0
python/dgl/distributed/graph_services.py
python/dgl/distributed/graph_services.py
+3
-0
python/dgl/distributed/id_map.py
python/dgl/distributed/id_map.py
+7
-0
python/dgl/distributed/partition.py
python/dgl/distributed/partition.py
+65
-1
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+37
-7
tests/distributed/test_mp_dataloader.py
tests/distributed/test_mp_dataloader.py
+7
-5
tests/distributed/test_partition.py
tests/distributed/test_partition.py
+30
-0
No files found.
python/dgl/distributed/dist_graph.py
View file @
d1bea9e8
...
@@ -811,7 +811,12 @@ class DistGraph:
...
@@ -811,7 +811,12 @@ class DistGraph:
int
int
"""
"""
# TODO(da?): describe when self._g is None and idtype shouldn't be called.
# TODO(da?): describe when self._g is None and idtype shouldn't be called.
return
F
.
int64
# For GraphBolt partition, we use the global node ID's dtype.
return
(
self
.
get_partition_book
().
global_nid_dtype
if
self
.
_use_graphbolt
else
F
.
int64
)
@
property
@
property
def
device
(
self
):
def
device
(
self
):
...
...
python/dgl/distributed/graph_partition_book.py
View file @
d1bea9e8
...
@@ -945,6 +945,16 @@ class RangePartitionBook(GraphPartitionBook):
...
@@ -945,6 +945,16 @@ class RangePartitionBook(GraphPartitionBook):
)
)
return
ret
return
ret
@
property
def
global_nid_dtype
(
self
):
"""Get the node ID's dtype"""
return
self
.
_nid_map
.
torch_dtype
@
property
def
global_eid_dtype
(
self
):
"""Get the edge ID's dtype"""
return
self
.
_eid_map
.
torch_dtype
NODE_PART_POLICY
=
"node"
NODE_PART_POLICY
=
"node"
EDGE_PART_POLICY
=
"edge"
EDGE_PART_POLICY
=
"edge"
...
...
python/dgl/distributed/graph_services.py
View file @
d1bea9e8
...
@@ -124,6 +124,9 @@ def _sample_neighbors_graphbolt(
...
@@ -124,6 +124,9 @@ def _sample_neighbors_graphbolt(
# 1. Map global node IDs to local node IDs.
# 1. Map global node IDs to local node IDs.
nodes
=
gpb
.
nid2localnid
(
nodes
,
gpb
.
partid
)
nodes
=
gpb
.
nid2localnid
(
nodes
,
gpb
.
partid
)
# Local partition may be saved in torch.int32 even though the global graph
# is in torch.int64.
nodes
=
nodes
.
to
(
dtype
=
g
.
indices
.
dtype
)
# 2. Perform sampling.
# 2. Perform sampling.
# [Rui][TODO] `prob` and `replace` are not tested yet. Skip for now.
# [Rui][TODO] `prob` and `replace` are not tested yet. Skip for now.
...
...
python/dgl/distributed/id_map.py
View file @
d1bea9e8
"""Module for mapping between node/edge IDs and node/edge types."""
"""Module for mapping between node/edge IDs and node/edge types."""
import
numpy
as
np
import
numpy
as
np
import
torch
from
..
import
backend
as
F
,
utils
from
..
import
backend
as
F
,
utils
...
@@ -167,5 +168,11 @@ class IdMap:
...
@@ -167,5 +168,11 @@ class IdMap:
ret
=
utils
.
toindex
(
ret
,
dtype
=
self
.
dtype_str
).
tousertensor
()
ret
=
utils
.
toindex
(
ret
,
dtype
=
self
.
dtype_str
).
tousertensor
()
return
ret
[:
len
(
ids
)],
ret
[
len
(
ids
)
:]
return
ret
[:
len
(
ids
)],
ret
[
len
(
ids
)
:]
@
property
def
torch_dtype
(
self
):
"""Return the data type of the ID map."""
# [TODO][Rui] Use torch instead of numpy.
return
torch
.
int32
if
self
.
dtype
==
np
.
int32
else
torch
.
int64
_init_api
(
"dgl.distributed.id_map"
)
_init_api
(
"dgl.distributed.id_map"
)
python/dgl/distributed/partition.py
View file @
d1bea9e8
...
@@ -1346,6 +1346,34 @@ def partition_graph(
...
@@ -1346,6 +1346,34 @@ def partition_graph(
return
orig_nids
,
orig_eids
return
orig_nids
,
orig_eids
DTYPES_TO_CHECK
=
{
"default"
:
[
torch
.
int32
,
torch
.
int64
],
NID
:
[
torch
.
int32
,
torch
.
int64
],
EID
:
[
torch
.
int32
,
torch
.
int64
],
NTYPE
:
[
torch
.
int8
,
torch
.
int16
,
torch
.
int32
,
torch
.
int64
],
ETYPE
:
[
torch
.
int8
,
torch
.
int16
,
torch
.
int32
,
torch
.
int64
],
"inner_node"
:
[
torch
.
uint8
],
"inner_edge"
:
[
torch
.
uint8
],
"part_id"
:
[
torch
.
int8
,
torch
.
int16
,
torch
.
int32
,
torch
.
int64
],
}
def
_cast_to_minimum_dtype
(
predicate
,
data
,
field
=
None
):
if
data
is
None
:
return
data
dtypes_to_check
=
DTYPES_TO_CHECK
.
get
(
field
,
DTYPES_TO_CHECK
[
"default"
])
if
data
.
dtype
not
in
dtypes_to_check
:
dgl_warning
(
f
"Skipping as the data type of field
{
field
}
is
{
data
.
dtype
}
, "
f
"while supported data types are
{
dtypes_to_check
}
."
)
return
data
for
dtype
in
dtypes_to_check
:
if
predicate
<
torch
.
iinfo
(
dtype
).
max
:
return
data
.
to
(
dtype
)
return
data
def
dgl_partition_to_graphbolt
(
def
dgl_partition_to_graphbolt
(
part_config
,
part_config
,
*
,
*
,
...
@@ -1459,6 +1487,31 @@ def dgl_partition_to_graphbolt(
...
@@ -1459,6 +1487,31 @@ def dgl_partition_to_graphbolt(
attr
:
graph
.
edata
[
attr
][
edge_ids
]
for
attr
in
required_edge_attrs
attr
:
graph
.
edata
[
attr
][
edge_ids
]
for
attr
in
required_edge_attrs
}
}
# Cast various data to minimum dtype.
# Cast 1: indptr.
indptr
=
_cast_to_minimum_dtype
(
graph
.
num_edges
(),
indptr
)
# Cast 2: indices.
indices
=
_cast_to_minimum_dtype
(
graph
.
num_nodes
(),
indices
)
# Cast 3: type_per_edge.
type_per_edge
=
_cast_to_minimum_dtype
(
len
(
etypes
),
type_per_edge
,
field
=
ETYPE
)
# Cast 4: node/edge_attributes.
predicates
=
{
NID
:
part_meta
[
"num_nodes"
],
"part_id"
:
num_parts
,
NTYPE
:
len
(
ntypes
),
EID
:
part_meta
[
"num_edges"
],
ETYPE
:
len
(
etypes
),
}
for
attributes
in
[
node_attributes
,
edge_attributes
]:
for
key
in
attributes
:
if
key
not
in
predicates
:
continue
attributes
[
key
]
=
_cast_to_minimum_dtype
(
predicates
[
key
],
attributes
[
key
],
field
=
key
)
csc_graph
=
gb
.
fused_csc_sampling_graph
(
csc_graph
=
gb
.
fused_csc_sampling_graph
(
indptr
,
indptr
,
indices
,
indices
,
...
@@ -1483,6 +1536,17 @@ def dgl_partition_to_graphbolt(
...
@@ -1483,6 +1536,17 @@ def dgl_partition_to_graphbolt(
"part_graph_graphbolt"
"part_graph_graphbolt"
]
=
os
.
path
.
relpath
(
csc_graph_path
,
os
.
path
.
dirname
(
part_config
))
]
=
os
.
path
.
relpath
(
csc_graph_path
,
os
.
path
.
dirname
(
part_config
))
# Update partition config.
# Save dtype info into partition config.
new_part_meta
[
"node_map_dtype"
]
=
(
"int32"
if
part_meta
[
"num_nodes"
]
<=
torch
.
iinfo
(
torch
.
int32
).
max
else
"int64"
)
new_part_meta
[
"edge_map_dtype"
]
=
(
"int32"
if
part_meta
[
"num_edges"
]
<=
torch
.
iinfo
(
torch
.
int32
).
max
else
"int64"
)
_dump_part_config
(
part_config
,
new_part_meta
)
_dump_part_config
(
part_config
,
new_part_meta
)
print
(
f
"Converted partitions to GraphBolt format into
{
part_config
}
"
)
print
(
f
"Converted partitions to GraphBolt format into
{
part_config
}
"
)
tests/distributed/test_distributed_sampling.py
View file @
d1bea9e8
...
@@ -79,6 +79,7 @@ def start_sample_client_shuffle(
...
@@ -79,6 +79,7 @@ def start_sample_client_shuffle(
orig_eid
,
orig_eid
,
use_graphbolt
=
False
,
use_graphbolt
=
False
,
return_eids
=
False
,
return_eids
=
False
,
node_id_dtype
=
None
,
):
):
os
.
environ
[
"DGL_GROUP_ID"
]
=
str
(
group_id
)
os
.
environ
[
"DGL_GROUP_ID"
]
=
str
(
group_id
)
gpb
=
None
gpb
=
None
...
@@ -90,10 +91,17 @@ def start_sample_client_shuffle(
...
@@ -90,10 +91,17 @@ def start_sample_client_shuffle(
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
sampled_graph
=
sample_neighbors
(
sampled_graph
=
sample_neighbors
(
dist_graph
,
dist_graph
,
torch
.
tensor
([
0
,
10
,
99
,
66
,
1024
,
2008
],
dtype
=
dist_graph
.
i
dtype
),
torch
.
tensor
([
0
,
10
,
99
,
66
,
1024
,
2008
],
dtype
=
node_id_
dtype
),
3
,
3
,
use_graphbolt
=
use_graphbolt
,
use_graphbolt
=
use_graphbolt
,
)
)
assert
sampled_graph
.
idtype
==
dist_graph
.
idtype
if
use_graphbolt
:
# dtype conversion is applied for GraphBolt partitions.
assert
sampled_graph
.
idtype
==
torch
.
int32
else
:
# dtype conversion is not applied for non-GraphBolt partitions.
assert
sampled_graph
.
idtype
==
torch
.
int64
assert
(
assert
(
dgl
.
ETYPE
not
in
sampled_graph
.
edata
dgl
.
ETYPE
not
in
sampled_graph
.
edata
...
@@ -399,7 +407,12 @@ def test_rpc_sampling():
...
@@ -399,7 +407,12 @@ def test_rpc_sampling():
def
check_rpc_sampling_shuffle
(
def
check_rpc_sampling_shuffle
(
tmpdir
,
num_server
,
num_groups
=
1
,
use_graphbolt
=
False
,
return_eids
=
False
tmpdir
,
num_server
,
num_groups
=
1
,
use_graphbolt
=
False
,
return_eids
=
False
,
node_id_dtype
=
None
,
):
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
@@ -454,6 +467,7 @@ def check_rpc_sampling_shuffle(
...
@@ -454,6 +467,7 @@ def check_rpc_sampling_shuffle(
orig_eids
,
orig_eids
,
use_graphbolt
,
use_graphbolt
,
return_eids
,
return_eids
,
node_id_dtype
,
),
),
)
)
p
.
start
()
p
.
start
()
...
@@ -485,6 +499,9 @@ def start_hetero_sample_client(
...
@@ -485,6 +499,9 @@ def start_hetero_sample_client(
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
nodes
=
{
k
:
torch
.
tensor
(
v
,
dtype
=
dist_graph
.
idtype
)
for
k
,
v
in
nodes
.
items
()
}
if
gpb
is
None
:
if
gpb
is
None
:
gpb
=
dist_graph
.
get_partition_book
()
gpb
=
dist_graph
.
get_partition_book
()
try
:
try
:
...
@@ -523,6 +540,9 @@ def start_hetero_etype_sample_client(
...
@@ -523,6 +540,9 @@ def start_hetero_etype_sample_client(
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
nodes
=
{
k
:
torch
.
tensor
(
v
,
dtype
=
dist_graph
.
idtype
)
for
k
,
v
in
nodes
.
items
()
}
if
(
not
use_graphbolt
)
and
dist_graph
.
local_partition
is
not
None
:
if
(
not
use_graphbolt
)
and
dist_graph
.
local_partition
is
not
None
:
# Check whether etypes are sorted in dist_graph
# Check whether etypes are sorted in dist_graph
...
@@ -684,7 +704,7 @@ def check_rpc_hetero_sampling_empty_shuffle(
...
@@ -684,7 +704,7 @@ def check_rpc_hetero_sampling_empty_shuffle(
pserver_list
.
append
(
p
)
pserver_list
.
append
(
p
)
deg
=
get_degrees
(
g
,
orig_nids
[
"n3"
],
"n3"
)
deg
=
get_degrees
(
g
,
orig_nids
[
"n3"
],
"n3"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
.
to
(
g
.
idtype
)
block
,
gpb
=
start_hetero_sample_client
(
block
,
gpb
=
start_hetero_sample_client
(
0
,
0
,
tmpdir
,
tmpdir
,
...
@@ -834,7 +854,7 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(
...
@@ -834,7 +854,7 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(
fanout
=
3
fanout
=
3
deg
=
get_degrees
(
g
,
orig_nids
[
"n3"
],
"n3"
)
deg
=
get_degrees
(
g
,
orig_nids
[
"n3"
],
"n3"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
.
to
(
g
.
idtype
)
block
,
gpb
=
start_hetero_etype_sample_client
(
block
,
gpb
=
start_hetero_etype_sample_client
(
0
,
0
,
tmpdir
,
tmpdir
,
...
@@ -881,6 +901,9 @@ def start_bipartite_sample_client(
...
@@ -881,6 +901,9 @@ def start_bipartite_sample_client(
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
nodes
=
{
k
:
torch
.
tensor
(
v
,
dtype
=
dist_graph
.
idtype
)
for
k
,
v
in
nodes
.
items
()
}
if
gpb
is
None
:
if
gpb
is
None
:
gpb
=
dist_graph
.
get_partition_book
()
gpb
=
dist_graph
.
get_partition_book
()
# Enable santity check in distributed sampling.
# Enable santity check in distributed sampling.
...
@@ -914,6 +937,9 @@ def start_bipartite_etype_sample_client(
...
@@ -914,6 +937,9 @@ def start_bipartite_etype_sample_client(
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
nodes
=
{
k
:
torch
.
tensor
(
v
,
dtype
=
dist_graph
.
idtype
)
for
k
,
v
in
nodes
.
items
()
}
if
not
use_graphbolt
and
dist_graph
.
local_partition
is
not
None
:
if
not
use_graphbolt
and
dist_graph
.
local_partition
is
not
None
:
# Check whether etypes are sorted in dist_graph
# Check whether etypes are sorted in dist_graph
...
@@ -979,7 +1005,7 @@ def check_rpc_bipartite_sampling_empty(
...
@@ -979,7 +1005,7 @@ def check_rpc_bipartite_sampling_empty(
pserver_list
.
append
(
p
)
pserver_list
.
append
(
p
)
deg
=
get_degrees
(
g
,
orig_nids
[
"game"
],
"game"
)
deg
=
get_degrees
(
g
,
orig_nids
[
"game"
],
"game"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
.
to
(
g
.
idtype
)
nodes
=
{
"game"
:
empty_nids
,
"user"
:
torch
.
tensor
([
1
],
dtype
=
g
.
idtype
)}
nodes
=
{
"game"
:
empty_nids
,
"user"
:
torch
.
tensor
([
1
],
dtype
=
g
.
idtype
)}
block
,
_
=
start_bipartite_sample_client
(
block
,
_
=
start_bipartite_sample_client
(
0
,
0
,
...
@@ -1120,7 +1146,7 @@ def check_rpc_bipartite_etype_sampling_empty(
...
@@ -1120,7 +1146,7 @@ def check_rpc_bipartite_etype_sampling_empty(
pserver_list
.
append
(
p
)
pserver_list
.
append
(
p
)
deg
=
get_degrees
(
g
,
orig_nids
[
"game"
],
"game"
)
deg
=
get_degrees
(
g
,
orig_nids
[
"game"
],
"game"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
.
to
(
g
.
idtype
)
nodes
=
{
"game"
:
empty_nids
,
"user"
:
torch
.
tensor
([
1
],
dtype
=
g
.
idtype
)}
nodes
=
{
"game"
:
empty_nids
,
"user"
:
torch
.
tensor
([
1
],
dtype
=
g
.
idtype
)}
block
,
_
=
start_bipartite_etype_sample_client
(
block
,
_
=
start_bipartite_etype_sample_client
(
0
,
0
,
...
@@ -1225,7 +1251,10 @@ def check_rpc_bipartite_etype_sampling_shuffle(
...
@@ -1225,7 +1251,10 @@ def check_rpc_bipartite_etype_sampling_shuffle(
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
@
pytest
.
mark
.
parametrize
(
"node_id_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
def
test_rpc_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
,
node_id_dtype
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
...
@@ -1234,6 +1263,7 @@ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
...
@@ -1234,6 +1263,7 @@ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
num_server
,
num_server
,
use_graphbolt
=
use_graphbolt
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
return_eids
=
return_eids
,
node_id_dtype
=
node_id_dtype
,
)
)
...
...
tests/distributed/test_mp_dataloader.py
View file @
d1bea9e8
...
@@ -41,7 +41,7 @@ class NeighborSampler(object):
...
@@ -41,7 +41,7 @@ class NeighborSampler(object):
def
sample_blocks
(
self
,
seeds
):
def
sample_blocks
(
self
,
seeds
):
import
torch
as
th
import
torch
as
th
seeds
=
th
.
LongT
ensor
(
np
.
asarray
(
seeds
))
seeds
=
th
.
t
ensor
(
np
.
asarray
(
seeds
)
,
dtype
=
self
.
g
.
idtype
)
blocks
=
[]
blocks
=
[]
for
fanout
in
self
.
fanouts
:
for
fanout
in
self
.
fanouts
:
# For each seed node, sample ``fanout`` neighbors.
# For each seed node, sample ``fanout`` neighbors.
...
@@ -124,7 +124,7 @@ def start_dist_dataloader(
...
@@ -124,7 +124,7 @@ def start_dist_dataloader(
for
i
in
range
(
2
):
for
i
in
range
(
2
):
# Create DataLoader for constructing blocks
# Create DataLoader for constructing blocks
dataloader
=
DistDataLoader
(
dataloader
=
DistDataLoader
(
dataset
=
train_nid
.
numpy
()
,
dataset
=
train_nid
,
batch_size
=
batch_size
,
batch_size
=
batch_size
,
collate_fn
=
sampler
.
sample_blocks
,
collate_fn
=
sampler
.
sample_blocks
,
shuffle
=
False
,
shuffle
=
False
,
...
@@ -448,9 +448,11 @@ def start_node_dataloader(
...
@@ -448,9 +448,11 @@ def start_node_dataloader(
assert
len
(
dist_graph
.
ntypes
)
==
len
(
groundtruth_g
.
ntypes
)
assert
len
(
dist_graph
.
ntypes
)
==
len
(
groundtruth_g
.
ntypes
)
assert
len
(
dist_graph
.
etypes
)
==
len
(
groundtruth_g
.
etypes
)
assert
len
(
dist_graph
.
etypes
)
==
len
(
groundtruth_g
.
etypes
)
if
len
(
dist_graph
.
etypes
)
==
1
:
if
len
(
dist_graph
.
etypes
)
==
1
:
train_nid
=
th
.
arange
(
num_nodes_to_sample
)
train_nid
=
th
.
arange
(
num_nodes_to_sample
,
dtype
=
dist_graph
.
idtype
)
else
:
else
:
train_nid
=
{
"n3"
:
th
.
arange
(
num_nodes_to_sample
)}
train_nid
=
{
"n3"
:
th
.
arange
(
num_nodes_to_sample
,
dtype
=
dist_graph
.
idtype
)
}
for
i
in
range
(
num_server
):
for
i
in
range
(
num_server
):
part
,
_
,
_
,
_
,
_
,
_
,
_
=
load_partition
(
part_config
,
i
)
part
,
_
,
_
,
_
,
_
,
_
,
_
=
load_partition
(
part_config
,
i
)
...
@@ -765,7 +767,7 @@ def start_multiple_dataloaders(
...
@@ -765,7 +767,7 @@ def start_multiple_dataloaders(
dgl
.
distributed
.
initialize
(
ip_config
)
dgl
.
distributed
.
initialize
(
ip_config
)
dist_g
=
dgl
.
distributed
.
DistGraph
(
graph_name
,
part_config
=
part_config
)
dist_g
=
dgl
.
distributed
.
DistGraph
(
graph_name
,
part_config
=
part_config
)
if
dataloader_type
==
"node"
:
if
dataloader_type
==
"node"
:
train_ids
=
th
.
arange
(
orig_g
.
num_nodes
())
train_ids
=
th
.
arange
(
orig_g
.
num_nodes
()
,
dtype
=
dist_g
.
idtype
)
batch_size
=
orig_g
.
num_nodes
()
//
100
batch_size
=
orig_g
.
num_nodes
()
//
100
else
:
else
:
train_ids
=
th
.
arange
(
orig_g
.
num_edges
())
train_ids
=
th
.
arange
(
orig_g
.
num_edges
())
...
...
tests/distributed/test_partition.py
View file @
d1bea9e8
...
@@ -768,9 +768,18 @@ def test_dgl_partition_to_graphbolt_homo(
...
@@ -768,9 +768,18 @@ def test_dgl_partition_to_graphbolt_homo(
part_config
,
part_id
,
load_feats
=
False
,
use_graphbolt
=
True
part_config
,
part_id
,
load_feats
=
False
,
use_graphbolt
=
True
)[
0
]
)[
0
]
orig_indptr
,
orig_indices
,
orig_eids
=
orig_g
.
adj
().
csc
()
orig_indptr
,
orig_indices
,
orig_eids
=
orig_g
.
adj
().
csc
()
# The original graph is in int64 while the partitioned graph is in
# int32 as dtype formatting is applied when converting to graphbolt
# format.
assert
orig_indptr
.
dtype
==
th
.
int64
assert
orig_indices
.
dtype
==
th
.
int64
assert
new_g
.
csc_indptr
.
dtype
==
th
.
int32
assert
new_g
.
indices
.
dtype
==
th
.
int32
assert
th
.
equal
(
orig_indptr
,
new_g
.
csc_indptr
)
assert
th
.
equal
(
orig_indptr
,
new_g
.
csc_indptr
)
assert
th
.
equal
(
orig_indices
,
new_g
.
indices
)
assert
th
.
equal
(
orig_indices
,
new_g
.
indices
)
assert
new_g
.
node_type_offset
is
None
assert
new_g
.
node_type_offset
is
None
assert
orig_g
.
ndata
[
dgl
.
NID
].
dtype
==
th
.
int64
assert
new_g
.
node_attributes
[
dgl
.
NID
].
dtype
==
th
.
int32
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
ndata
[
dgl
.
NID
],
new_g
.
node_attributes
[
dgl
.
NID
]
orig_g
.
ndata
[
dgl
.
NID
],
new_g
.
node_attributes
[
dgl
.
NID
]
)
)
...
@@ -782,6 +791,8 @@ def test_dgl_partition_to_graphbolt_homo(
...
@@ -782,6 +791,8 @@ def test_dgl_partition_to_graphbolt_homo(
else
:
else
:
assert
"inner_node"
not
in
new_g
.
node_attributes
assert
"inner_node"
not
in
new_g
.
node_attributes
if
store_eids
or
debug_mode
:
if
store_eids
or
debug_mode
:
assert
orig_g
.
edata
[
dgl
.
EID
].
dtype
==
th
.
int64
assert
new_g
.
edge_attributes
[
dgl
.
EID
].
dtype
==
th
.
int32
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
edata
[
dgl
.
EID
][
orig_eids
],
orig_g
.
edata
[
dgl
.
EID
][
orig_eids
],
new_g
.
edge_attributes
[
dgl
.
EID
],
new_g
.
edge_attributes
[
dgl
.
EID
],
...
@@ -789,6 +800,8 @@ def test_dgl_partition_to_graphbolt_homo(
...
@@ -789,6 +800,8 @@ def test_dgl_partition_to_graphbolt_homo(
else
:
else
:
assert
dgl
.
EID
not
in
new_g
.
edge_attributes
assert
dgl
.
EID
not
in
new_g
.
edge_attributes
if
store_inner_edge
or
debug_mode
:
if
store_inner_edge
or
debug_mode
:
assert
orig_g
.
edata
[
"inner_edge"
].
dtype
==
th
.
uint8
assert
new_g
.
edge_attributes
[
"inner_edge"
].
dtype
==
th
.
uint8
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
edata
[
"inner_edge"
][
orig_eids
],
orig_g
.
edata
[
"inner_edge"
][
orig_eids
],
new_g
.
edge_attributes
[
"inner_edge"
],
new_g
.
edge_attributes
[
"inner_edge"
],
...
@@ -838,8 +851,17 @@ def test_dgl_partition_to_graphbolt_hetero(
...
@@ -838,8 +851,17 @@ def test_dgl_partition_to_graphbolt_hetero(
part_config
,
part_id
,
load_feats
=
False
,
use_graphbolt
=
True
part_config
,
part_id
,
load_feats
=
False
,
use_graphbolt
=
True
)[
0
]
)[
0
]
orig_indptr
,
orig_indices
,
orig_eids
=
orig_g
.
adj
().
csc
()
orig_indptr
,
orig_indices
,
orig_eids
=
orig_g
.
adj
().
csc
()
# The original graph is in int64 while the partitioned graph is in
# int32 as dtype formatting is applied when converting to graphbolt
# format.
assert
orig_indptr
.
dtype
==
th
.
int64
assert
orig_indices
.
dtype
==
th
.
int64
assert
new_g
.
csc_indptr
.
dtype
==
th
.
int32
assert
new_g
.
indices
.
dtype
==
th
.
int32
assert
th
.
equal
(
orig_indptr
,
new_g
.
csc_indptr
)
assert
th
.
equal
(
orig_indptr
,
new_g
.
csc_indptr
)
assert
th
.
equal
(
orig_indices
,
new_g
.
indices
)
assert
th
.
equal
(
orig_indices
,
new_g
.
indices
)
assert
orig_g
.
ndata
[
dgl
.
NID
].
dtype
==
th
.
int64
assert
new_g
.
node_attributes
[
dgl
.
NID
].
dtype
==
th
.
int32
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
ndata
[
dgl
.
NID
],
new_g
.
node_attributes
[
dgl
.
NID
]
orig_g
.
ndata
[
dgl
.
NID
],
new_g
.
node_attributes
[
dgl
.
NID
]
)
)
...
@@ -851,12 +873,16 @@ def test_dgl_partition_to_graphbolt_hetero(
...
@@ -851,12 +873,16 @@ def test_dgl_partition_to_graphbolt_hetero(
else
:
else
:
assert
"inner_node"
not
in
new_g
.
node_attributes
assert
"inner_node"
not
in
new_g
.
node_attributes
if
debug_mode
:
if
debug_mode
:
assert
orig_g
.
ndata
[
dgl
.
NTYPE
].
dtype
==
th
.
int32
assert
new_g
.
node_attributes
[
dgl
.
NTYPE
].
dtype
==
th
.
int8
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
ndata
[
dgl
.
NTYPE
],
new_g
.
node_attributes
[
dgl
.
NTYPE
]
orig_g
.
ndata
[
dgl
.
NTYPE
],
new_g
.
node_attributes
[
dgl
.
NTYPE
]
)
)
else
:
else
:
assert
dgl
.
NTYPE
not
in
new_g
.
node_attributes
assert
dgl
.
NTYPE
not
in
new_g
.
node_attributes
if
store_eids
or
debug_mode
:
if
store_eids
or
debug_mode
:
assert
orig_g
.
edata
[
dgl
.
EID
].
dtype
==
th
.
int64
assert
new_g
.
edge_attributes
[
dgl
.
EID
].
dtype
==
th
.
int32
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
edata
[
dgl
.
EID
][
orig_eids
],
orig_g
.
edata
[
dgl
.
EID
][
orig_eids
],
new_g
.
edge_attributes
[
dgl
.
EID
],
new_g
.
edge_attributes
[
dgl
.
EID
],
...
@@ -864,6 +890,8 @@ def test_dgl_partition_to_graphbolt_hetero(
...
@@ -864,6 +890,8 @@ def test_dgl_partition_to_graphbolt_hetero(
else
:
else
:
assert
dgl
.
EID
not
in
new_g
.
edge_attributes
assert
dgl
.
EID
not
in
new_g
.
edge_attributes
if
store_inner_edge
or
debug_mode
:
if
store_inner_edge
or
debug_mode
:
assert
orig_g
.
edata
[
"inner_edge"
].
dtype
==
th
.
uint8
assert
new_g
.
edge_attributes
[
"inner_edge"
].
dtype
==
th
.
uint8
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
edata
[
"inner_edge"
],
orig_g
.
edata
[
"inner_edge"
],
new_g
.
edge_attributes
[
"inner_edge"
],
new_g
.
edge_attributes
[
"inner_edge"
],
...
@@ -871,6 +899,8 @@ def test_dgl_partition_to_graphbolt_hetero(
...
@@ -871,6 +899,8 @@ def test_dgl_partition_to_graphbolt_hetero(
else
:
else
:
assert
"inner_edge"
not
in
new_g
.
edge_attributes
assert
"inner_edge"
not
in
new_g
.
edge_attributes
if
debug_mode
:
if
debug_mode
:
assert
orig_g
.
edata
[
dgl
.
ETYPE
].
dtype
==
th
.
int32
assert
new_g
.
edge_attributes
[
dgl
.
ETYPE
].
dtype
==
th
.
int8
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
edata
[
dgl
.
ETYPE
][
orig_eids
],
orig_g
.
edata
[
dgl
.
ETYPE
][
orig_eids
],
new_g
.
edge_attributes
[
dgl
.
ETYPE
],
new_g
.
edge_attributes
[
dgl
.
ETYPE
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment