"docs/source/reference/launcher.md" did not exist on "532146338bfcc6af86efbe61825206a9e913f37f"
Unverified Commit d8d87243 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] restrict NID/EID as int64_t (#7177)

parent ade806b4
...@@ -1346,10 +1346,12 @@ def partition_graph( ...@@ -1346,10 +1346,12 @@ def partition_graph(
return orig_nids, orig_eids return orig_nids, orig_eids
# [TODO][Rui] Due to int64_t is expected in RPC, we have to limit the data type
# of node/edge IDs to int64_t. See more details in #7175.
DTYPES_TO_CHECK = { DTYPES_TO_CHECK = {
"default": [torch.int32, torch.int64], "default": [torch.int32, torch.int64],
NID: [torch.int32, torch.int64], NID: [torch.int64],
EID: [torch.int32, torch.int64], EID: [torch.int64],
NTYPE: [torch.int8, torch.int16, torch.int32, torch.int64], NTYPE: [torch.int8, torch.int16, torch.int32, torch.int64],
ETYPE: [torch.int8, torch.int16, torch.int32, torch.int64], ETYPE: [torch.int8, torch.int16, torch.int32, torch.int64],
"inner_node": [torch.uint8], "inner_node": [torch.uint8],
...@@ -1537,16 +1539,10 @@ def dgl_partition_to_graphbolt( ...@@ -1537,16 +1539,10 @@ def dgl_partition_to_graphbolt(
] = os.path.relpath(csc_graph_path, os.path.dirname(part_config)) ] = os.path.relpath(csc_graph_path, os.path.dirname(part_config))
# Save dtype info into partition config. # Save dtype info into partition config.
new_part_meta["node_map_dtype"] = ( # [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more
"int32" # details in #7175.
if part_meta["num_nodes"] <= torch.iinfo(torch.int32).max new_part_meta["node_map_dtype"] = "int64"
else "int64" new_part_meta["edge_map_dtype"] = "int64"
)
new_part_meta["edge_map_dtype"] = (
"int32"
if part_meta["num_edges"] <= torch.iinfo(torch.int32).max
else "int64"
)
_dump_part_config(part_config, new_part_meta) _dump_part_config(part_config, new_part_meta)
print(f"Converted partitions to GraphBolt format into {part_config}") print(f"Converted partitions to GraphBolt format into {part_config}")
...@@ -96,11 +96,6 @@ def start_sample_client_shuffle( ...@@ -96,11 +96,6 @@ def start_sample_client_shuffle(
use_graphbolt=use_graphbolt, use_graphbolt=use_graphbolt,
) )
assert sampled_graph.idtype == dist_graph.idtype assert sampled_graph.idtype == dist_graph.idtype
if use_graphbolt:
# dtype conversion is applied for GraphBolt partitions.
assert sampled_graph.idtype == torch.int32
else:
# dtype conversion is not applied for non-GraphBolt partitions.
assert sampled_graph.idtype == torch.int64 assert sampled_graph.idtype == torch.int64
assert ( assert (
...@@ -1251,7 +1246,7 @@ def check_rpc_bipartite_etype_sampling_shuffle( ...@@ -1251,7 +1246,7 @@ def check_rpc_bipartite_etype_sampling_shuffle(
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True]) @pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True]) @pytest.mark.parametrize("return_eids", [False, True])
@pytest.mark.parametrize("node_id_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("node_id_dtype", [torch.int64])
def test_rpc_sampling_shuffle( def test_rpc_sampling_shuffle(
num_server, use_graphbolt, return_eids, node_id_dtype num_server, use_graphbolt, return_eids, node_id_dtype
): ):
......
...@@ -779,7 +779,7 @@ def test_dgl_partition_to_graphbolt_homo( ...@@ -779,7 +779,7 @@ def test_dgl_partition_to_graphbolt_homo(
assert th.equal(orig_indices, new_g.indices) assert th.equal(orig_indices, new_g.indices)
assert new_g.node_type_offset is None assert new_g.node_type_offset is None
assert orig_g.ndata[dgl.NID].dtype == th.int64 assert orig_g.ndata[dgl.NID].dtype == th.int64
assert new_g.node_attributes[dgl.NID].dtype == th.int32 assert new_g.node_attributes[dgl.NID].dtype == th.int64
assert th.equal( assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID] orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
) )
...@@ -792,7 +792,7 @@ def test_dgl_partition_to_graphbolt_homo( ...@@ -792,7 +792,7 @@ def test_dgl_partition_to_graphbolt_homo(
assert "inner_node" not in new_g.node_attributes assert "inner_node" not in new_g.node_attributes
if store_eids or debug_mode: if store_eids or debug_mode:
assert orig_g.edata[dgl.EID].dtype == th.int64 assert orig_g.edata[dgl.EID].dtype == th.int64
assert new_g.edge_attributes[dgl.EID].dtype == th.int32 assert new_g.edge_attributes[dgl.EID].dtype == th.int64
assert th.equal( assert th.equal(
orig_g.edata[dgl.EID][orig_eids], orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID], new_g.edge_attributes[dgl.EID],
...@@ -861,7 +861,7 @@ def test_dgl_partition_to_graphbolt_hetero( ...@@ -861,7 +861,7 @@ def test_dgl_partition_to_graphbolt_hetero(
assert th.equal(orig_indptr, new_g.csc_indptr) assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices) assert th.equal(orig_indices, new_g.indices)
assert orig_g.ndata[dgl.NID].dtype == th.int64 assert orig_g.ndata[dgl.NID].dtype == th.int64
assert new_g.node_attributes[dgl.NID].dtype == th.int32 assert new_g.node_attributes[dgl.NID].dtype == th.int64
assert th.equal( assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID] orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
) )
...@@ -882,7 +882,7 @@ def test_dgl_partition_to_graphbolt_hetero( ...@@ -882,7 +882,7 @@ def test_dgl_partition_to_graphbolt_hetero(
assert dgl.NTYPE not in new_g.node_attributes assert dgl.NTYPE not in new_g.node_attributes
if store_eids or debug_mode: if store_eids or debug_mode:
assert orig_g.edata[dgl.EID].dtype == th.int64 assert orig_g.edata[dgl.EID].dtype == th.int64
assert new_g.edge_attributes[dgl.EID].dtype == th.int32 assert new_g.edge_attributes[dgl.EID].dtype == th.int64
assert th.equal( assert th.equal(
orig_g.edata[dgl.EID][orig_eids], orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID], new_g.edge_attributes[dgl.EID],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment