Unverified Commit d8d87243 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] restrict NID/EID as int64_t (#7177)

parent ade806b4
......@@ -1346,10 +1346,12 @@ def partition_graph(
return orig_nids, orig_eids
# [TODO][Rui] Due to int64_t is expected in RPC, we have to limit the data type
# of node/edge IDs to int64_t. See more details in #7175.
DTYPES_TO_CHECK = {
"default": [torch.int32, torch.int64],
NID: [torch.int32, torch.int64],
EID: [torch.int32, torch.int64],
NID: [torch.int64],
EID: [torch.int64],
NTYPE: [torch.int8, torch.int16, torch.int32, torch.int64],
ETYPE: [torch.int8, torch.int16, torch.int32, torch.int64],
"inner_node": [torch.uint8],
......@@ -1537,16 +1539,10 @@ def dgl_partition_to_graphbolt(
] = os.path.relpath(csc_graph_path, os.path.dirname(part_config))
# Save dtype info into partition config.
new_part_meta["node_map_dtype"] = (
"int32"
if part_meta["num_nodes"] <= torch.iinfo(torch.int32).max
else "int64"
)
new_part_meta["edge_map_dtype"] = (
"int32"
if part_meta["num_edges"] <= torch.iinfo(torch.int32).max
else "int64"
)
# [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more
# details in #7175.
new_part_meta["node_map_dtype"] = "int64"
new_part_meta["edge_map_dtype"] = "int64"
_dump_part_config(part_config, new_part_meta)
print(f"Converted partitions to GraphBolt format into {part_config}")
......@@ -96,12 +96,7 @@ def start_sample_client_shuffle(
use_graphbolt=use_graphbolt,
)
assert sampled_graph.idtype == dist_graph.idtype
if use_graphbolt:
# dtype conversion is applied for GraphBolt partitions.
assert sampled_graph.idtype == torch.int32
else:
# dtype conversion is not applied for non-GraphBolt partitions.
assert sampled_graph.idtype == torch.int64
assert sampled_graph.idtype == torch.int64
assert (
dgl.ETYPE not in sampled_graph.edata
......@@ -1251,7 +1246,7 @@ def check_rpc_bipartite_etype_sampling_shuffle(
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
@pytest.mark.parametrize("node_id_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("node_id_dtype", [torch.int64])
def test_rpc_sampling_shuffle(
num_server, use_graphbolt, return_eids, node_id_dtype
):
......
......@@ -779,7 +779,7 @@ def test_dgl_partition_to_graphbolt_homo(
assert th.equal(orig_indices, new_g.indices)
assert new_g.node_type_offset is None
assert orig_g.ndata[dgl.NID].dtype == th.int64
assert new_g.node_attributes[dgl.NID].dtype == th.int32
assert new_g.node_attributes[dgl.NID].dtype == th.int64
assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
)
......@@ -792,7 +792,7 @@ def test_dgl_partition_to_graphbolt_homo(
assert "inner_node" not in new_g.node_attributes
if store_eids or debug_mode:
assert orig_g.edata[dgl.EID].dtype == th.int64
assert new_g.edge_attributes[dgl.EID].dtype == th.int32
assert new_g.edge_attributes[dgl.EID].dtype == th.int64
assert th.equal(
orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID],
......@@ -861,7 +861,7 @@ def test_dgl_partition_to_graphbolt_hetero(
assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices)
assert orig_g.ndata[dgl.NID].dtype == th.int64
assert new_g.node_attributes[dgl.NID].dtype == th.int32
assert new_g.node_attributes[dgl.NID].dtype == th.int64
assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
)
......@@ -882,7 +882,7 @@ def test_dgl_partition_to_graphbolt_hetero(
assert dgl.NTYPE not in new_g.node_attributes
if store_eids or debug_mode:
assert orig_g.edata[dgl.EID].dtype == th.int64
assert new_g.edge_attributes[dgl.EID].dtype == th.int32
assert new_g.edge_attributes[dgl.EID].dtype == th.int64
assert th.equal(
orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment