Unverified Commit d1bea9e8 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] format dtype when converting partition to GraphBolt format (#7150)

parent 4091a49c
...@@ -811,7 +811,12 @@ class DistGraph: ...@@ -811,7 +811,12 @@ class DistGraph:
int int
""" """
# TODO(da?): describe when self._g is None and idtype shouldn't be called. # TODO(da?): describe when self._g is None and idtype shouldn't be called.
return F.int64 # For GraphBolt partition, we use the global node ID's dtype.
return (
self.get_partition_book().global_nid_dtype
if self._use_graphbolt
else F.int64
)
@property @property
def device(self): def device(self):
......
...@@ -945,6 +945,16 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -945,6 +945,16 @@ class RangePartitionBook(GraphPartitionBook):
) )
return ret return ret
@property
def global_nid_dtype(self):
    """Data type of global node IDs, as recorded by the node ID map."""
    nid_map = self._nid_map
    return nid_map.torch_dtype
@property
def global_eid_dtype(self):
    """Data type of global edge IDs, as recorded by the edge ID map."""
    eid_map = self._eid_map
    return eid_map.torch_dtype
NODE_PART_POLICY = "node" NODE_PART_POLICY = "node"
EDGE_PART_POLICY = "edge" EDGE_PART_POLICY = "edge"
......
...@@ -124,6 +124,9 @@ def _sample_neighbors_graphbolt( ...@@ -124,6 +124,9 @@ def _sample_neighbors_graphbolt(
# 1. Map global node IDs to local node IDs. # 1. Map global node IDs to local node IDs.
nodes = gpb.nid2localnid(nodes, gpb.partid) nodes = gpb.nid2localnid(nodes, gpb.partid)
# Local partition may be saved in torch.int32 even though the global graph
# is in torch.int64.
nodes = nodes.to(dtype=g.indices.dtype)
# 2. Perform sampling. # 2. Perform sampling.
# [Rui][TODO] `prob` and `replace` are not tested yet. Skip for now. # [Rui][TODO] `prob` and `replace` are not tested yet. Skip for now.
......
"""Module for mapping between node/edge IDs and node/edge types.""" """Module for mapping between node/edge IDs and node/edge types."""
import numpy as np import numpy as np
import torch
from .. import backend as F, utils from .. import backend as F, utils
...@@ -167,5 +168,11 @@ class IdMap: ...@@ -167,5 +168,11 @@ class IdMap:
ret = utils.toindex(ret, dtype=self.dtype_str).tousertensor() ret = utils.toindex(ret, dtype=self.dtype_str).tousertensor()
return ret[: len(ids)], ret[len(ids) :] return ret[: len(ids)], ret[len(ids) :]
@property
def torch_dtype(self):
    """Return the torch dtype matching this ID map's numpy dtype.

    Only int32/int64 maps exist; any dtype other than ``np.int32`` falls
    back to the int64 default.
    """
    # [TODO][Rui] Use torch instead of numpy.
    if self.dtype == np.int32:
        return torch.int32
    return torch.int64
_init_api("dgl.distributed.id_map") _init_api("dgl.distributed.id_map")
...@@ -1346,6 +1346,34 @@ def partition_graph( ...@@ -1346,6 +1346,34 @@ def partition_graph(
return orig_nids, orig_eids return orig_nids, orig_eids
# Allowed integer dtypes per field when down-casting partition data to the
# minimum-width dtype. Fields absent from this map fall back to "default".
# Each candidate list is ordered from narrowest to widest;
# `_cast_to_minimum_dtype` picks the first one wide enough.
DTYPES_TO_CHECK = {
    "default": [torch.int32, torch.int64],
    # Global node/edge IDs must stay at least 32-bit.
    NID: [torch.int32, torch.int64],
    EID: [torch.int32, torch.int64],
    # Node/edge type IDs are small enumerations and may shrink to 8-bit.
    NTYPE: [torch.int8, torch.int16, torch.int32, torch.int64],
    ETYPE: [torch.int8, torch.int16, torch.int32, torch.int64],
    # Inner-node/edge markers are boolean flags stored as uint8.
    "inner_node": [torch.uint8],
    "inner_edge": [torch.uint8],
    "part_id": [torch.int8, torch.int16, torch.int32, torch.int64],
}
def _cast_to_minimum_dtype(predicate, data, field=None):
    """Cast ``data`` to the narrowest allowed dtype that can hold ``predicate``.

    Parameters
    ----------
    predicate : int
        Upper bound of the values stored in ``data`` (e.g. the number of
        nodes/edges or the number of types). The chosen dtype must be able
        to represent this value itself: e.g. the last entry of ``indptr``
        equals the edge count exactly.
    data : torch.Tensor or None
        The tensor to cast. ``None`` is passed through unchanged.
    field : str, optional
        Field name used to look up the allowed dtypes in ``DTYPES_TO_CHECK``.
        Unknown or unspecified fields use the "default" candidate list.

    Returns
    -------
    torch.Tensor or None
        ``data`` cast to the minimum sufficient dtype, or returned unchanged
        if its current dtype is not in the allowed list or no candidate is
        wide enough.
    """
    if data is None:
        return data
    dtypes_to_check = DTYPES_TO_CHECK.get(field, DTYPES_TO_CHECK["default"])
    if data.dtype not in dtypes_to_check:
        # Leave unexpected dtypes (e.g. floats) untouched instead of failing.
        dgl_warning(
            f"Skipping as the data type of field {field} is {data.dtype}, "
            f"while supported data types are {dtypes_to_check}."
        )
        return data
    # Candidate lists are ordered narrowest-first, so the first match is the
    # minimum sufficient dtype. Use `<=` rather than `<`: `predicate` itself
    # must be representable, and this matches the
    # `num_nodes <= torch.iinfo(torch.int32).max` check used when writing
    # `node_map_dtype`/`edge_map_dtype` into the partition config.
    for dtype in dtypes_to_check:
        if predicate <= torch.iinfo(dtype).max:
            return data.to(dtype)
    return data
def dgl_partition_to_graphbolt( def dgl_partition_to_graphbolt(
part_config, part_config,
*, *,
...@@ -1459,6 +1487,31 @@ def dgl_partition_to_graphbolt( ...@@ -1459,6 +1487,31 @@ def dgl_partition_to_graphbolt(
attr: graph.edata[attr][edge_ids] for attr in required_edge_attrs attr: graph.edata[attr][edge_ids] for attr in required_edge_attrs
} }
# Cast various data to minimum dtype.
# Cast 1: indptr.
indptr = _cast_to_minimum_dtype(graph.num_edges(), indptr)
# Cast 2: indices.
indices = _cast_to_minimum_dtype(graph.num_nodes(), indices)
# Cast 3: type_per_edge.
type_per_edge = _cast_to_minimum_dtype(
len(etypes), type_per_edge, field=ETYPE
)
# Cast 4: node/edge_attributes.
predicates = {
NID: part_meta["num_nodes"],
"part_id": num_parts,
NTYPE: len(ntypes),
EID: part_meta["num_edges"],
ETYPE: len(etypes),
}
for attributes in [node_attributes, edge_attributes]:
for key in attributes:
if key not in predicates:
continue
attributes[key] = _cast_to_minimum_dtype(
predicates[key], attributes[key], field=key
)
csc_graph = gb.fused_csc_sampling_graph( csc_graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
...@@ -1483,6 +1536,17 @@ def dgl_partition_to_graphbolt( ...@@ -1483,6 +1536,17 @@ def dgl_partition_to_graphbolt(
"part_graph_graphbolt" "part_graph_graphbolt"
] = os.path.relpath(csc_graph_path, os.path.dirname(part_config)) ] = os.path.relpath(csc_graph_path, os.path.dirname(part_config))
# Update partition config. # Save dtype info into partition config.
new_part_meta["node_map_dtype"] = (
"int32"
if part_meta["num_nodes"] <= torch.iinfo(torch.int32).max
else "int64"
)
new_part_meta["edge_map_dtype"] = (
"int32"
if part_meta["num_edges"] <= torch.iinfo(torch.int32).max
else "int64"
)
_dump_part_config(part_config, new_part_meta) _dump_part_config(part_config, new_part_meta)
print(f"Converted partitions to GraphBolt format into {part_config}") print(f"Converted partitions to GraphBolt format into {part_config}")
...@@ -79,6 +79,7 @@ def start_sample_client_shuffle( ...@@ -79,6 +79,7 @@ def start_sample_client_shuffle(
orig_eid, orig_eid,
use_graphbolt=False, use_graphbolt=False,
return_eids=False, return_eids=False,
node_id_dtype=None,
): ):
os.environ["DGL_GROUP_ID"] = str(group_id) os.environ["DGL_GROUP_ID"] = str(group_id)
gpb = None gpb = None
...@@ -90,10 +91,17 @@ def start_sample_client_shuffle( ...@@ -90,10 +91,17 @@ def start_sample_client_shuffle(
dist_graph = DistGraph("test_sampling", gpb=gpb) dist_graph = DistGraph("test_sampling", gpb=gpb)
sampled_graph = sample_neighbors( sampled_graph = sample_neighbors(
dist_graph, dist_graph,
torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype), torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=node_id_dtype),
3, 3,
use_graphbolt=use_graphbolt, use_graphbolt=use_graphbolt,
) )
assert sampled_graph.idtype == dist_graph.idtype
if use_graphbolt:
# dtype conversion is applied for GraphBolt partitions.
assert sampled_graph.idtype == torch.int32
else:
# dtype conversion is not applied for non-GraphBolt partitions.
assert sampled_graph.idtype == torch.int64
assert ( assert (
dgl.ETYPE not in sampled_graph.edata dgl.ETYPE not in sampled_graph.edata
...@@ -399,7 +407,12 @@ def test_rpc_sampling(): ...@@ -399,7 +407,12 @@ def test_rpc_sampling():
def check_rpc_sampling_shuffle( def check_rpc_sampling_shuffle(
tmpdir, num_server, num_groups=1, use_graphbolt=False, return_eids=False tmpdir,
num_server,
num_groups=1,
use_graphbolt=False,
return_eids=False,
node_id_dtype=None,
): ):
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
...@@ -454,6 +467,7 @@ def check_rpc_sampling_shuffle( ...@@ -454,6 +467,7 @@ def check_rpc_sampling_shuffle(
orig_eids, orig_eids,
use_graphbolt, use_graphbolt,
return_eids, return_eids,
node_id_dtype,
), ),
) )
p.start() p.start()
...@@ -485,6 +499,9 @@ def start_hetero_sample_client( ...@@ -485,6 +499,9 @@ def start_hetero_sample_client(
assert "feat" in dist_graph.nodes["n1"].data assert "feat" in dist_graph.nodes["n1"].data
assert "feat" not in dist_graph.nodes["n2"].data assert "feat" not in dist_graph.nodes["n2"].data
assert "feat" not in dist_graph.nodes["n3"].data assert "feat" not in dist_graph.nodes["n3"].data
nodes = {
k: torch.tensor(v, dtype=dist_graph.idtype) for k, v in nodes.items()
}
if gpb is None: if gpb is None:
gpb = dist_graph.get_partition_book() gpb = dist_graph.get_partition_book()
try: try:
...@@ -523,6 +540,9 @@ def start_hetero_etype_sample_client( ...@@ -523,6 +540,9 @@ def start_hetero_etype_sample_client(
assert "feat" in dist_graph.nodes["n1"].data assert "feat" in dist_graph.nodes["n1"].data
assert "feat" not in dist_graph.nodes["n2"].data assert "feat" not in dist_graph.nodes["n2"].data
assert "feat" not in dist_graph.nodes["n3"].data assert "feat" not in dist_graph.nodes["n3"].data
nodes = {
k: torch.tensor(v, dtype=dist_graph.idtype) for k, v in nodes.items()
}
if (not use_graphbolt) and dist_graph.local_partition is not None: if (not use_graphbolt) and dist_graph.local_partition is not None:
# Check whether etypes are sorted in dist_graph # Check whether etypes are sorted in dist_graph
...@@ -684,7 +704,7 @@ def check_rpc_hetero_sampling_empty_shuffle( ...@@ -684,7 +704,7 @@ def check_rpc_hetero_sampling_empty_shuffle(
pserver_list.append(p) pserver_list.append(p)
deg = get_degrees(g, orig_nids["n3"], "n3") deg = get_degrees(g, orig_nids["n3"], "n3")
empty_nids = F.nonzero_1d(deg == 0) empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)
block, gpb = start_hetero_sample_client( block, gpb = start_hetero_sample_client(
0, 0,
tmpdir, tmpdir,
...@@ -834,7 +854,7 @@ def check_rpc_hetero_etype_sampling_empty_shuffle( ...@@ -834,7 +854,7 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(
fanout = 3 fanout = 3
deg = get_degrees(g, orig_nids["n3"], "n3") deg = get_degrees(g, orig_nids["n3"], "n3")
empty_nids = F.nonzero_1d(deg == 0) empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)
block, gpb = start_hetero_etype_sample_client( block, gpb = start_hetero_etype_sample_client(
0, 0,
tmpdir, tmpdir,
...@@ -881,6 +901,9 @@ def start_bipartite_sample_client( ...@@ -881,6 +901,9 @@ def start_bipartite_sample_client(
dist_graph = DistGraph("test_sampling", gpb=gpb) dist_graph = DistGraph("test_sampling", gpb=gpb)
assert "feat" in dist_graph.nodes["user"].data assert "feat" in dist_graph.nodes["user"].data
assert "feat" in dist_graph.nodes["game"].data assert "feat" in dist_graph.nodes["game"].data
nodes = {
k: torch.tensor(v, dtype=dist_graph.idtype) for k, v in nodes.items()
}
if gpb is None: if gpb is None:
gpb = dist_graph.get_partition_book() gpb = dist_graph.get_partition_book()
# Enable sanity check in distributed sampling. # Enable sanity check in distributed sampling.
...@@ -914,6 +937,9 @@ def start_bipartite_etype_sample_client( ...@@ -914,6 +937,9 @@ def start_bipartite_etype_sample_client(
dist_graph = DistGraph("test_sampling", gpb=gpb) dist_graph = DistGraph("test_sampling", gpb=gpb)
assert "feat" in dist_graph.nodes["user"].data assert "feat" in dist_graph.nodes["user"].data
assert "feat" in dist_graph.nodes["game"].data assert "feat" in dist_graph.nodes["game"].data
nodes = {
k: torch.tensor(v, dtype=dist_graph.idtype) for k, v in nodes.items()
}
if not use_graphbolt and dist_graph.local_partition is not None: if not use_graphbolt and dist_graph.local_partition is not None:
# Check whether etypes are sorted in dist_graph # Check whether etypes are sorted in dist_graph
...@@ -979,7 +1005,7 @@ def check_rpc_bipartite_sampling_empty( ...@@ -979,7 +1005,7 @@ def check_rpc_bipartite_sampling_empty(
pserver_list.append(p) pserver_list.append(p)
deg = get_degrees(g, orig_nids["game"], "game") deg = get_degrees(g, orig_nids["game"], "game")
empty_nids = F.nonzero_1d(deg == 0) empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)
nodes = {"game": empty_nids, "user": torch.tensor([1], dtype=g.idtype)} nodes = {"game": empty_nids, "user": torch.tensor([1], dtype=g.idtype)}
block, _ = start_bipartite_sample_client( block, _ = start_bipartite_sample_client(
0, 0,
...@@ -1120,7 +1146,7 @@ def check_rpc_bipartite_etype_sampling_empty( ...@@ -1120,7 +1146,7 @@ def check_rpc_bipartite_etype_sampling_empty(
pserver_list.append(p) pserver_list.append(p)
deg = get_degrees(g, orig_nids["game"], "game") deg = get_degrees(g, orig_nids["game"], "game")
empty_nids = F.nonzero_1d(deg == 0) empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)
nodes = {"game": empty_nids, "user": torch.tensor([1], dtype=g.idtype)} nodes = {"game": empty_nids, "user": torch.tensor([1], dtype=g.idtype)}
block, _ = start_bipartite_etype_sample_client( block, _ = start_bipartite_etype_sample_client(
0, 0,
...@@ -1225,7 +1251,10 @@ def check_rpc_bipartite_etype_sampling_shuffle( ...@@ -1225,7 +1251,10 @@ def check_rpc_bipartite_etype_sampling_shuffle(
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True]) @pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True]) @pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids): @pytest.mark.parametrize("node_id_dtype", [torch.int32, torch.int64])
def test_rpc_sampling_shuffle(
num_server, use_graphbolt, return_eids, node_id_dtype
):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
...@@ -1234,6 +1263,7 @@ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids): ...@@ -1234,6 +1263,7 @@ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
num_server, num_server,
use_graphbolt=use_graphbolt, use_graphbolt=use_graphbolt,
return_eids=return_eids, return_eids=return_eids,
node_id_dtype=node_id_dtype,
) )
......
...@@ -41,7 +41,7 @@ class NeighborSampler(object): ...@@ -41,7 +41,7 @@ class NeighborSampler(object):
def sample_blocks(self, seeds): def sample_blocks(self, seeds):
import torch as th import torch as th
seeds = th.LongTensor(np.asarray(seeds)) seeds = th.tensor(np.asarray(seeds), dtype=self.g.idtype)
blocks = [] blocks = []
for fanout in self.fanouts: for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors. # For each seed node, sample ``fanout`` neighbors.
...@@ -124,7 +124,7 @@ def start_dist_dataloader( ...@@ -124,7 +124,7 @@ def start_dist_dataloader(
for i in range(2): for i in range(2):
# Create DataLoader for constructing blocks # Create DataLoader for constructing blocks
dataloader = DistDataLoader( dataloader = DistDataLoader(
dataset=train_nid.numpy(), dataset=train_nid,
batch_size=batch_size, batch_size=batch_size,
collate_fn=sampler.sample_blocks, collate_fn=sampler.sample_blocks,
shuffle=False, shuffle=False,
...@@ -448,9 +448,11 @@ def start_node_dataloader( ...@@ -448,9 +448,11 @@ def start_node_dataloader(
assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes) assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
assert len(dist_graph.etypes) == len(groundtruth_g.etypes) assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
if len(dist_graph.etypes) == 1: if len(dist_graph.etypes) == 1:
train_nid = th.arange(num_nodes_to_sample) train_nid = th.arange(num_nodes_to_sample, dtype=dist_graph.idtype)
else: else:
train_nid = {"n3": th.arange(num_nodes_to_sample)} train_nid = {
"n3": th.arange(num_nodes_to_sample, dtype=dist_graph.idtype)
}
for i in range(num_server): for i in range(num_server):
part, _, _, _, _, _, _ = load_partition(part_config, i) part, _, _, _, _, _, _ = load_partition(part_config, i)
...@@ -765,7 +767,7 @@ def start_multiple_dataloaders( ...@@ -765,7 +767,7 @@ def start_multiple_dataloaders(
dgl.distributed.initialize(ip_config) dgl.distributed.initialize(ip_config)
dist_g = dgl.distributed.DistGraph(graph_name, part_config=part_config) dist_g = dgl.distributed.DistGraph(graph_name, part_config=part_config)
if dataloader_type == "node": if dataloader_type == "node":
train_ids = th.arange(orig_g.num_nodes()) train_ids = th.arange(orig_g.num_nodes(), dtype=dist_g.idtype)
batch_size = orig_g.num_nodes() // 100 batch_size = orig_g.num_nodes() // 100
else: else:
train_ids = th.arange(orig_g.num_edges()) train_ids = th.arange(orig_g.num_edges())
......
...@@ -768,9 +768,18 @@ def test_dgl_partition_to_graphbolt_homo( ...@@ -768,9 +768,18 @@ def test_dgl_partition_to_graphbolt_homo(
part_config, part_id, load_feats=False, use_graphbolt=True part_config, part_id, load_feats=False, use_graphbolt=True
)[0] )[0]
orig_indptr, orig_indices, orig_eids = orig_g.adj().csc() orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()
# The original graph is in int64 while the partitioned graph is in
# int32 as dtype formatting is applied when converting to graphbolt
# format.
assert orig_indptr.dtype == th.int64
assert orig_indices.dtype == th.int64
assert new_g.csc_indptr.dtype == th.int32
assert new_g.indices.dtype == th.int32
assert th.equal(orig_indptr, new_g.csc_indptr) assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices) assert th.equal(orig_indices, new_g.indices)
assert new_g.node_type_offset is None assert new_g.node_type_offset is None
assert orig_g.ndata[dgl.NID].dtype == th.int64
assert new_g.node_attributes[dgl.NID].dtype == th.int32
assert th.equal( assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID] orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
) )
...@@ -782,6 +791,8 @@ def test_dgl_partition_to_graphbolt_homo( ...@@ -782,6 +791,8 @@ def test_dgl_partition_to_graphbolt_homo(
else: else:
assert "inner_node" not in new_g.node_attributes assert "inner_node" not in new_g.node_attributes
if store_eids or debug_mode: if store_eids or debug_mode:
assert orig_g.edata[dgl.EID].dtype == th.int64
assert new_g.edge_attributes[dgl.EID].dtype == th.int32
assert th.equal( assert th.equal(
orig_g.edata[dgl.EID][orig_eids], orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID], new_g.edge_attributes[dgl.EID],
...@@ -789,6 +800,8 @@ def test_dgl_partition_to_graphbolt_homo( ...@@ -789,6 +800,8 @@ def test_dgl_partition_to_graphbolt_homo(
else: else:
assert dgl.EID not in new_g.edge_attributes assert dgl.EID not in new_g.edge_attributes
if store_inner_edge or debug_mode: if store_inner_edge or debug_mode:
assert orig_g.edata["inner_edge"].dtype == th.uint8
assert new_g.edge_attributes["inner_edge"].dtype == th.uint8
assert th.equal( assert th.equal(
orig_g.edata["inner_edge"][orig_eids], orig_g.edata["inner_edge"][orig_eids],
new_g.edge_attributes["inner_edge"], new_g.edge_attributes["inner_edge"],
...@@ -838,8 +851,17 @@ def test_dgl_partition_to_graphbolt_hetero( ...@@ -838,8 +851,17 @@ def test_dgl_partition_to_graphbolt_hetero(
part_config, part_id, load_feats=False, use_graphbolt=True part_config, part_id, load_feats=False, use_graphbolt=True
)[0] )[0]
orig_indptr, orig_indices, orig_eids = orig_g.adj().csc() orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()
# The original graph is in int64 while the partitioned graph is in
# int32 as dtype formatting is applied when converting to graphbolt
# format.
assert orig_indptr.dtype == th.int64
assert orig_indices.dtype == th.int64
assert new_g.csc_indptr.dtype == th.int32
assert new_g.indices.dtype == th.int32
assert th.equal(orig_indptr, new_g.csc_indptr) assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices) assert th.equal(orig_indices, new_g.indices)
assert orig_g.ndata[dgl.NID].dtype == th.int64
assert new_g.node_attributes[dgl.NID].dtype == th.int32
assert th.equal( assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID] orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
) )
...@@ -851,12 +873,16 @@ def test_dgl_partition_to_graphbolt_hetero( ...@@ -851,12 +873,16 @@ def test_dgl_partition_to_graphbolt_hetero(
else: else:
assert "inner_node" not in new_g.node_attributes assert "inner_node" not in new_g.node_attributes
if debug_mode: if debug_mode:
assert orig_g.ndata[dgl.NTYPE].dtype == th.int32
assert new_g.node_attributes[dgl.NTYPE].dtype == th.int8
assert th.equal( assert th.equal(
orig_g.ndata[dgl.NTYPE], new_g.node_attributes[dgl.NTYPE] orig_g.ndata[dgl.NTYPE], new_g.node_attributes[dgl.NTYPE]
) )
else: else:
assert dgl.NTYPE not in new_g.node_attributes assert dgl.NTYPE not in new_g.node_attributes
if store_eids or debug_mode: if store_eids or debug_mode:
assert orig_g.edata[dgl.EID].dtype == th.int64
assert new_g.edge_attributes[dgl.EID].dtype == th.int32
assert th.equal( assert th.equal(
orig_g.edata[dgl.EID][orig_eids], orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID], new_g.edge_attributes[dgl.EID],
...@@ -864,6 +890,8 @@ def test_dgl_partition_to_graphbolt_hetero( ...@@ -864,6 +890,8 @@ def test_dgl_partition_to_graphbolt_hetero(
else: else:
assert dgl.EID not in new_g.edge_attributes assert dgl.EID not in new_g.edge_attributes
if store_inner_edge or debug_mode: if store_inner_edge or debug_mode:
assert orig_g.edata["inner_edge"].dtype == th.uint8
assert new_g.edge_attributes["inner_edge"].dtype == th.uint8
assert th.equal( assert th.equal(
orig_g.edata["inner_edge"], orig_g.edata["inner_edge"],
new_g.edge_attributes["inner_edge"], new_g.edge_attributes["inner_edge"],
...@@ -871,6 +899,8 @@ def test_dgl_partition_to_graphbolt_hetero( ...@@ -871,6 +899,8 @@ def test_dgl_partition_to_graphbolt_hetero(
else: else:
assert "inner_edge" not in new_g.edge_attributes assert "inner_edge" not in new_g.edge_attributes
if debug_mode: if debug_mode:
assert orig_g.edata[dgl.ETYPE].dtype == th.int32
assert new_g.edge_attributes[dgl.ETYPE].dtype == th.int8
assert th.equal( assert th.equal(
orig_g.edata[dgl.ETYPE][orig_eids], orig_g.edata[dgl.ETYPE][orig_eids],
new_g.edge_attributes[dgl.ETYPE], new_g.edge_attributes[dgl.ETYPE],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment