Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
d8d87243
Unverified
Commit
d8d87243
authored
Feb 29, 2024
by
Rhett Ying
Committed by
GitHub
Feb 29, 2024
Browse files
[DistGB] restrict NID/EID as int64_t (#7177)
parent
ade806b4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
23 deletions
+14
-23
python/dgl/distributed/partition.py
python/dgl/distributed/partition.py
+8
-12
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+2
-7
tests/distributed/test_partition.py
tests/distributed/test_partition.py
+4
-4
No files found.
python/dgl/distributed/partition.py
View file @
d8d87243
...
@@ -1346,10 +1346,12 @@ def partition_graph(
...
@@ -1346,10 +1346,12 @@ def partition_graph(
return
orig_nids
,
orig_eids
return
orig_nids
,
orig_eids
# [TODO][Rui] Due to int64_t is expected in RPC, we have to limit the data type
# of node/edge IDs to int64_t. See more details in #7175.
DTYPES_TO_CHECK
=
{
DTYPES_TO_CHECK
=
{
"default"
:
[
torch
.
int32
,
torch
.
int64
],
"default"
:
[
torch
.
int32
,
torch
.
int64
],
NID
:
[
torch
.
int32
,
torch
.
int64
],
NID
:
[
torch
.
int64
],
EID
:
[
torch
.
int32
,
torch
.
int64
],
EID
:
[
torch
.
int64
],
NTYPE
:
[
torch
.
int8
,
torch
.
int16
,
torch
.
int32
,
torch
.
int64
],
NTYPE
:
[
torch
.
int8
,
torch
.
int16
,
torch
.
int32
,
torch
.
int64
],
ETYPE
:
[
torch
.
int8
,
torch
.
int16
,
torch
.
int32
,
torch
.
int64
],
ETYPE
:
[
torch
.
int8
,
torch
.
int16
,
torch
.
int32
,
torch
.
int64
],
"inner_node"
:
[
torch
.
uint8
],
"inner_node"
:
[
torch
.
uint8
],
...
@@ -1537,16 +1539,10 @@ def dgl_partition_to_graphbolt(
...
@@ -1537,16 +1539,10 @@ def dgl_partition_to_graphbolt(
]
=
os
.
path
.
relpath
(
csc_graph_path
,
os
.
path
.
dirname
(
part_config
))
]
=
os
.
path
.
relpath
(
csc_graph_path
,
os
.
path
.
dirname
(
part_config
))
# Save dtype info into partition config.
# Save dtype info into partition config.
new_part_meta
[
"node_map_dtype"
]
=
(
# [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more
"int32"
# details in #7175.
if
part_meta
[
"num_nodes"
]
<=
torch
.
iinfo
(
torch
.
int32
).
max
new_part_meta
[
"node_map_dtype"
]
=
"int64"
else
"int64"
new_part_meta
[
"edge_map_dtype"
]
=
"int64"
)
new_part_meta
[
"edge_map_dtype"
]
=
(
"int32"
if
part_meta
[
"num_edges"
]
<=
torch
.
iinfo
(
torch
.
int32
).
max
else
"int64"
)
_dump_part_config
(
part_config
,
new_part_meta
)
_dump_part_config
(
part_config
,
new_part_meta
)
print
(
f
"Converted partitions to GraphBolt format into
{
part_config
}
"
)
print
(
f
"Converted partitions to GraphBolt format into
{
part_config
}
"
)
tests/distributed/test_distributed_sampling.py
View file @
d8d87243
...
@@ -96,12 +96,7 @@ def start_sample_client_shuffle(
...
@@ -96,12 +96,7 @@ def start_sample_client_shuffle(
use_graphbolt
=
use_graphbolt
,
use_graphbolt
=
use_graphbolt
,
)
)
assert
sampled_graph
.
idtype
==
dist_graph
.
idtype
assert
sampled_graph
.
idtype
==
dist_graph
.
idtype
if
use_graphbolt
:
assert
sampled_graph
.
idtype
==
torch
.
int64
# dtype conversion is applied for GraphBolt partitions.
assert
sampled_graph
.
idtype
==
torch
.
int32
else
:
# dtype conversion is not applied for non-GraphBolt partitions.
assert
sampled_graph
.
idtype
==
torch
.
int64
assert
(
assert
(
dgl
.
ETYPE
not
in
sampled_graph
.
edata
dgl
.
ETYPE
not
in
sampled_graph
.
edata
...
@@ -1251,7 +1246,7 @@ def check_rpc_bipartite_etype_sampling_shuffle(
...
@@ -1251,7 +1246,7 @@ def check_rpc_bipartite_etype_sampling_shuffle(
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"node_id_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"node_id_dtype"
,
[
torch
.
int64
])
def
test_rpc_sampling_shuffle
(
def
test_rpc_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
,
node_id_dtype
num_server
,
use_graphbolt
,
return_eids
,
node_id_dtype
):
):
...
...
tests/distributed/test_partition.py
View file @
d8d87243
...
@@ -779,7 +779,7 @@ def test_dgl_partition_to_graphbolt_homo(
...
@@ -779,7 +779,7 @@ def test_dgl_partition_to_graphbolt_homo(
assert
th
.
equal
(
orig_indices
,
new_g
.
indices
)
assert
th
.
equal
(
orig_indices
,
new_g
.
indices
)
assert
new_g
.
node_type_offset
is
None
assert
new_g
.
node_type_offset
is
None
assert
orig_g
.
ndata
[
dgl
.
NID
].
dtype
==
th
.
int64
assert
orig_g
.
ndata
[
dgl
.
NID
].
dtype
==
th
.
int64
assert
new_g
.
node_attributes
[
dgl
.
NID
].
dtype
==
th
.
int
32
assert
new_g
.
node_attributes
[
dgl
.
NID
].
dtype
==
th
.
int
64
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
ndata
[
dgl
.
NID
],
new_g
.
node_attributes
[
dgl
.
NID
]
orig_g
.
ndata
[
dgl
.
NID
],
new_g
.
node_attributes
[
dgl
.
NID
]
)
)
...
@@ -792,7 +792,7 @@ def test_dgl_partition_to_graphbolt_homo(
...
@@ -792,7 +792,7 @@ def test_dgl_partition_to_graphbolt_homo(
assert
"inner_node"
not
in
new_g
.
node_attributes
assert
"inner_node"
not
in
new_g
.
node_attributes
if
store_eids
or
debug_mode
:
if
store_eids
or
debug_mode
:
assert
orig_g
.
edata
[
dgl
.
EID
].
dtype
==
th
.
int64
assert
orig_g
.
edata
[
dgl
.
EID
].
dtype
==
th
.
int64
assert
new_g
.
edge_attributes
[
dgl
.
EID
].
dtype
==
th
.
int
32
assert
new_g
.
edge_attributes
[
dgl
.
EID
].
dtype
==
th
.
int
64
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
edata
[
dgl
.
EID
][
orig_eids
],
orig_g
.
edata
[
dgl
.
EID
][
orig_eids
],
new_g
.
edge_attributes
[
dgl
.
EID
],
new_g
.
edge_attributes
[
dgl
.
EID
],
...
@@ -861,7 +861,7 @@ def test_dgl_partition_to_graphbolt_hetero(
...
@@ -861,7 +861,7 @@ def test_dgl_partition_to_graphbolt_hetero(
assert
th
.
equal
(
orig_indptr
,
new_g
.
csc_indptr
)
assert
th
.
equal
(
orig_indptr
,
new_g
.
csc_indptr
)
assert
th
.
equal
(
orig_indices
,
new_g
.
indices
)
assert
th
.
equal
(
orig_indices
,
new_g
.
indices
)
assert
orig_g
.
ndata
[
dgl
.
NID
].
dtype
==
th
.
int64
assert
orig_g
.
ndata
[
dgl
.
NID
].
dtype
==
th
.
int64
assert
new_g
.
node_attributes
[
dgl
.
NID
].
dtype
==
th
.
int
32
assert
new_g
.
node_attributes
[
dgl
.
NID
].
dtype
==
th
.
int
64
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
ndata
[
dgl
.
NID
],
new_g
.
node_attributes
[
dgl
.
NID
]
orig_g
.
ndata
[
dgl
.
NID
],
new_g
.
node_attributes
[
dgl
.
NID
]
)
)
...
@@ -882,7 +882,7 @@ def test_dgl_partition_to_graphbolt_hetero(
...
@@ -882,7 +882,7 @@ def test_dgl_partition_to_graphbolt_hetero(
assert
dgl
.
NTYPE
not
in
new_g
.
node_attributes
assert
dgl
.
NTYPE
not
in
new_g
.
node_attributes
if
store_eids
or
debug_mode
:
if
store_eids
or
debug_mode
:
assert
orig_g
.
edata
[
dgl
.
EID
].
dtype
==
th
.
int64
assert
orig_g
.
edata
[
dgl
.
EID
].
dtype
==
th
.
int64
assert
new_g
.
edge_attributes
[
dgl
.
EID
].
dtype
==
th
.
int
32
assert
new_g
.
edge_attributes
[
dgl
.
EID
].
dtype
==
th
.
int
64
assert
th
.
equal
(
assert
th
.
equal
(
orig_g
.
edata
[
dgl
.
EID
][
orig_eids
],
orig_g
.
edata
[
dgl
.
EID
][
orig_eids
],
new_g
.
edge_attributes
[
dgl
.
EID
],
new_g
.
edge_attributes
[
dgl
.
EID
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment