Unverified Commit 5d2d1453 authored by Mingbang Wang, committed by GitHub

[GraphBolt] Support saving graph and TVT_Sets in `int32` (#7127)


Co-authored-by: Muhammed Fatih BALIN <m.f.balin@gmail.com>
parent 90a9136c
"""GraphBolt OnDiskDataset.""" """GraphBolt OnDiskDataset."""
import bisect
import json import json
import os import os
import shutil import shutil
...@@ -40,11 +41,20 @@ from .torch_based_feature_store import TorchBasedFeatureStore ...@@ -40,11 +41,20 @@ from .torch_based_feature_store import TorchBasedFeatureStore
__all__ = ["OnDiskDataset", "preprocess_ondisk_dataset", "BuiltinDataset"] __all__ = ["OnDiskDataset", "preprocess_ondisk_dataset", "BuiltinDataset"]
NAMES_INDICATING_NODE_IDS = [
"seed_nodes",
"node_pairs",
"seeds",
"negative_srcs",
"negative_dsts",
]
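These names flag ItemSet tensors that hold node IDs, and only such tensors are eligible for int32 narrowing during preprocessing. A minimal hedged sketch of the gating decision, with made-up values:

name = "seed_nodes"  # hypothetical TVT-set entry name from the yaml
node_ids_within_int32 = True  # i.e. the graph's indices were cast to int32
within_int32 = node_ids_within_int32 and name in NAMES_INDICATING_NODE_IDS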
def _graph_data_to_fused_csc_sampling_graph(
    dataset_dir: str,
    graph_data: Dict,
    include_original_edge_id: bool,
    auto_cast_to_optimal_dtype: bool,
) -> FusedCSCSamplingGraph:
    """Convert the raw graph data into FusedCSCSamplingGraph.
@@ -56,6 +66,9 @@ def _graph_data_to_fused_csc_sampling_graph(
        The raw data read from yaml file.
    include_original_edge_id : bool
        Whether to include the original edge id in the FusedCSCSamplingGraph.
    auto_cast_to_optimal_dtype : bool
        Casts the dtypes of tensors in the dataset to the smallest possible
        dtypes for reduced storage requirements and potentially increased
        performance.
    Returns
    -------
@@ -83,6 +96,14 @@ def _graph_data_to_fused_csc_sampling_graph(
    del coo_tensor
    indptr, indices, edge_ids = sparse_matrix.csc()
    del sparse_matrix
    if auto_cast_to_optimal_dtype:
        if num_nodes <= torch.iinfo(torch.int32).max:
            indices = indices.to(torch.int32)
        if num_edges <= torch.iinfo(torch.int32).max:
            indptr = indptr.to(torch.int32)
            edge_ids = edge_ids.to(torch.int32)
    node_type_offset = None
    type_per_edge = None
    node_type_to_id = None
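The guard above narrows node IDs only when every ID is representable in int32. A self-contained sketch of the same rule, with hypothetical sizes:

import torch

num_nodes, num_edges = 10_000, 60_000  # hypothetical graph sizes
indptr = torch.zeros(num_nodes + 1, dtype=torch.int64)
indices = torch.randint(0, num_nodes, (num_edges,), dtype=torch.int64)

# indices hold node IDs in [0, num_nodes), so they fit in int32 whenever
# num_nodes does; indptr entries and edge IDs count edges, so they are
# bounded by num_edges instead.
if num_nodes <= torch.iinfo(torch.int32).max:
    indices = indices.to(torch.int32)
if num_edges <= torch.iinfo(torch.int32).max:
    indptr = indptr.to(torch.int32)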
@@ -127,6 +148,14 @@ def _graph_data_to_fused_csc_sampling_graph(
    del coo_src_list
    coo_dst = torch.cat(coo_dst_list)
    del coo_dst_list
    if auto_cast_to_optimal_dtype:
        dtypes = [torch.uint8, torch.int16, torch.int32, torch.int64]
        dtype_maxes = [torch.iinfo(dtype).max for dtype in dtypes]
        dtype_id = bisect.bisect_left(dtype_maxes, len(edge_type_to_id) - 1)
        etype_dtype = dtypes[dtype_id]
        coo_etype_list = [
            tensor.to(etype_dtype) for tensor in coo_etype_list
        ]
    coo_etype = torch.cat(coo_etype_list)
    del coo_etype_list
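bisect.bisect_left returns the first dtype whose maximum value can represent the largest edge-type ID, len(edge_type_to_id) - 1. A small self-contained sketch of the selection:

import bisect
import torch

def smallest_etype_dtype(num_edge_types: int) -> torch.dtype:
    # Candidates ordered by width; pick the narrowest whose max value
    # can hold the largest type ID, num_edge_types - 1.
    dtypes = [torch.uint8, torch.int16, torch.int32, torch.int64]
    dtype_maxes = [torch.iinfo(dtype).max for dtype in dtypes]
    return dtypes[bisect.bisect_left(dtype_maxes, num_edge_types - 1)]

assert smallest_etype_dtype(3) == torch.uint8  # 2 <= 255
assert smallest_etype_dtype(300) == torch.int16  # 299 > 255
assert smallest_etype_dtype(70_000) == torch.int32  # 69999 > 32767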
@@ -137,17 +166,32 @@ def _graph_data_to_fused_csc_sampling_graph(
    del coo_src, coo_dst
    indptr, indices, edge_ids = sparse_matrix.csc()
    del sparse_matrix
    if auto_cast_to_optimal_dtype:
        if total_num_nodes <= torch.iinfo(torch.int32).max:
            indices = indices.to(torch.int32)
        if total_num_edges <= torch.iinfo(torch.int32).max:
            indptr = indptr.to(torch.int32)
            edge_ids = edge_ids.to(torch.int32)
    node_type_offset = torch.tensor(node_type_offset, dtype=indices.dtype)
    type_per_edge = torch.index_select(coo_etype, dim=0, index=edge_ids)
    del coo_etype
    node_attributes = {}
    edge_attributes = {}
    if include_original_edge_id:
        # If uint8 or int16 was chosen above for etypes, we cast to int.
        temp_etypes = (
            type_per_edge.int()
            if type_per_edge.element_size() < 4
            else type_per_edge
        )
        edge_ids -= torch.index_select(
            torch.tensor(edge_type_offset, dtype=edge_ids.dtype),
            dim=0,
            index=temp_etypes,
        )
        del temp_etypes
        edge_attributes[ORIGINAL_EDGE_ID] = edge_ids

    # Load the sampling related node/edge features and add them to
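The widening above is needed because torch.index_select accepts only int32 or int64 index tensors. A minimal sketch with hypothetical offsets:

import torch

offsets = torch.tensor([0, 4, 9])  # hypothetical per-type edge offsets
etypes = torch.tensor([0, 1, 1, 2], dtype=torch.uint8)

# Widen uint8/int16 before the lookup; index_select rejects them.
index = etypes.int() if etypes.element_size() < 4 else etypes
print(torch.index_select(offsets, dim=0, index=index))  # tensor([0, 4, 4, 9])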
@@ -279,6 +323,7 @@ def preprocess_ondisk_dataset(
    dataset_dir: str,
    include_original_edge_id: bool = False,
    force_preprocess: bool = None,
    auto_cast_to_optimal_dtype: bool = True,
) -> str:
    """Preprocess the on-disk dataset. Parse the input config file,
    load the data, and save the data in the format that GraphBolt supports.
@@ -291,6 +336,10 @@ def preprocess_ondisk_dataset(
        Whether to include the original edge id in the FusedCSCSamplingGraph.
    force_preprocess: bool, optional
        Whether to force reload the ondisk dataset.
    auto_cast_to_optimal_dtype: bool, optional
        Casts the dtypes of tensors in the dataset to the smallest possible
        dtypes for reduced storage requirements and potentially increased
        performance. Default is True.
    Returns
    -------
@@ -360,6 +409,7 @@ def preprocess_ondisk_dataset(
        dataset_dir,
        input_config["graph"],
        include_original_edge_id,
        auto_cast_to_optimal_dtype,
    )

    # 3. Record value of include_original_edge_id.
@@ -372,6 +422,10 @@ def preprocess_ondisk_dataset(
        processed_dir_prefix, "fused_csc_sampling_graph.pt"
    )
    node_ids_within_int32 = (
        sampling_graph.indices.dtype == torch.int32
        and auto_cast_to_optimal_dtype
    )
    torch.save(
        sampling_graph,
        os.path.join(
@@ -379,6 +433,7 @@ def preprocess_ondisk_dataset(
            output_config["graph_topology"]["path"],
        ),
    )
    del sampling_graph
    del output_config["graph"]

    # 5. Load the node/edge features and do necessary conversion.
@@ -428,11 +483,16 @@ def preprocess_ondisk_dataset(
                    processed_dir_prefix,
                    input_data["path"].replace("pt", "npy"),
                )
                name = (
                    input_data["name"] if "name" in input_data else None
                )
                copy_or_convert_data(
                    os.path.join(dataset_dir, input_data["path"]),
                    os.path.join(dataset_dir, output_data["path"]),
                    input_data["format"],
                    output_data["format"],
                    within_int32=node_ids_within_int32
                    and name in NAMES_INDICATING_NODE_IDS,
                )

    # 7. Save the output_config.
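For reference, a hedged usage sketch of the new flag (dataset path hypothetical; the directory must already contain a metadata.yaml in the on-disk format):

from dgl import graphbolt as gb

yaml_path = gb.ondisk_dataset.preprocess_ondisk_dataset(
    "datasets/example",  # hypothetical dataset directory
    include_original_edge_id=True,
    auto_cast_to_optimal_dtype=True,  # default; set False to keep int64
)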
@@ -610,6 +670,10 @@ class OnDiskDataset(Dataset):
        Whether to include the original edge id in the FusedCSCSamplingGraph.
    force_preprocess: bool, optional
        Whether to force reload the ondisk dataset.
    auto_cast_to_optimal_dtype: bool, optional
        Casts the dtypes of tensors in the dataset to the smallest possible
        dtypes for reduced storage requirements and potentially increased
        performance. Default is True.
""" """
def __init__( def __init__(
...@@ -617,12 +681,16 @@ class OnDiskDataset(Dataset): ...@@ -617,12 +681,16 @@ class OnDiskDataset(Dataset):
        path: str,
        include_original_edge_id: bool = False,
        force_preprocess: bool = None,
        auto_cast_to_optimal_dtype: bool = True,
    ) -> None:
        # Always call the preprocess function first. If already preprocessed,
        # the function will return the original path directly.
        self._dataset_dir = path
        yaml_path = preprocess_ondisk_dataset(
            path,
            include_original_edge_id,
            force_preprocess,
            auto_cast_to_optimal_dtype,
        )
        with open(yaml_path) as f:
            self._yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)
@@ -824,7 +892,7 @@ class OnDiskDataset(Dataset):
    def _init_all_nodes_set(self, graph) -> Union[ItemSet, ItemSetDict]:
        if graph is None:
            dgl_warning(
                "`all_nodes_set` is returned as None, since graph is None."
            )
            return None
        num_nodes = graph.num_nodes
...
@@ -85,6 +85,17 @@ def get_npy_dim(npy_path):
    return len(shape)
def _to_int32(data):
    if isinstance(data, torch.Tensor):
        return data.to(torch.int32)
    elif isinstance(data, np.ndarray):
        return data.astype(np.int32)
    else:
        raise TypeError(
            "Unsupported input type. Please provide a torch tensor or numpy array."
        )
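A quick illustration of the helper on both supported input types (values made up):

import numpy as np
import torch

assert _to_int32(torch.tensor([1, 2], dtype=torch.int64)).dtype == torch.int32
assert _to_int32(np.array([1, 2], dtype=np.int64)).dtype == np.int32
try:
    _to_int32([1, 2])  # plain Python lists are rejected
except TypeError as err:
    print(err)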
def copy_or_convert_data(
    input_path,
    output_path,
@@ -92,27 +103,30 @@ def copy_or_convert_data(
    output_format="numpy",
    in_memory=True,
    is_feature=False,
    within_int32=False,
):
"""Copy or convert the data from input_path to output_path.""" """Copy or convert the data from input_path to output_path."""
assert ( assert (
output_format == "numpy" output_format == "numpy"
), "The output format of the data should be numpy." ), "The output format of the data should be numpy."
os.makedirs(os.path.dirname(output_path), exist_ok=True) os.makedirs(os.path.dirname(output_path), exist_ok=True)
# If the original format is numpy, just copy the file. # We read the data always in case we need to cast its type.
data = read_data(input_path, input_format, in_memory)
if within_int32:
data = _to_int32(data)
if input_format == "numpy": if input_format == "numpy":
# If dim of the data is 1, reshape it to n * 1 and save it to output_path. # If dim of the data is 1, reshape it to n * 1 and save it to output_path.
if is_feature and get_npy_dim(input_path) == 1: if is_feature and get_npy_dim(input_path) == 1:
data = read_data(input_path, input_format, in_memory)
data = data.reshape(-1, 1) data = data.reshape(-1, 1)
save_data(data, output_path, output_format) # If the data does not need to be modified, just copy the file.
else: elif not within_int32:
shutil.copyfile(input_path, output_path) shutil.copyfile(input_path, output_path)
return
else: else:
# If the original format is not numpy, convert it to numpy. # If dim of the data is 1, reshape it to n * 1 and save it to output_path.
data = read_data(input_path, input_format, in_memory)
if is_feature and data.dim() == 1: if is_feature and data.dim() == 1:
data = data.reshape(-1, 1) data = data.reshape(-1, 1)
save_data(data, output_path, output_format) save_data(data, output_path, output_format)
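A hedged end-to-end sketch for one hypothetical TVT file, assuming the "torch" input format is handled by read_data as the call above suggests:

import numpy as np
import torch

torch.save(torch.arange(100, dtype=torch.int64), "train_seeds.pt")

# Reads the tensor, narrows it via _to_int32, and saves it as .npy.
copy_or_convert_data(
    "train_seeds.pt",
    "preprocessed/train_seeds.npy",
    input_format="torch",
    within_int32=True,
)
assert np.load("preprocessed/train_seeds.npy").dtype == np.int32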
def get_attributes(_obj) -> list:
...
@@ -1151,7 +1151,11 @@ def test_OnDiskDataset_preprocess_homogeneous(edge_fmt):
        num_samples = 100
        fanout = 1
        subgraph = fused_csc_sampling_graph.sample_neighbors(
            torch.arange(
                0,
                num_samples,
                dtype=fused_csc_sampling_graph.indices.dtype,
            ),
            torch.tensor([fanout]),
        )
        assert len(subgraph.sampled_csc.indices) <= num_samples
@@ -1191,7 +1195,10 @@ def test_OnDiskDataset_preprocess_homogeneous(edge_fmt):
        fused_csc_sampling_graph = None
@pytest.mark.parametrize("auto_cast", [False, True])
def test_OnDiskDataset_preprocess_homogeneous_hardcode(
    auto_cast, edge_fmt="numpy"
):
    """Test preprocess of OnDiskDataset."""
    with tempfile.TemporaryDirectory() as test_dir:
        """Original graph in COO:
@@ -1312,6 +1319,7 @@ def test_OnDiskDataset_preprocess_homogeneous_hardcode(edge_fmt="numpy"):
        output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(
            test_dir,
            include_original_edge_id=True,
            auto_cast_to_optimal_dtype=auto_cast,
        )
        with open(output_file, "rb") as f:
@@ -1351,16 +1359,31 @@ def test_OnDiskDataset_preprocess_homogeneous_hardcode(edge_fmt="numpy"):
            torch.tensor([7, 8, 0, 9, 1, 2, 3, 4, 5, 6]),
        )
        expected_dtype = torch.int32 if auto_cast else torch.int64
        assert fused_csc_sampling_graph.csc_indptr.dtype == expected_dtype
        assert fused_csc_sampling_graph.indices.dtype == expected_dtype
        assert (
            fused_csc_sampling_graph.edge_attributes[gb.ORIGINAL_EDGE_ID].dtype
            == expected_dtype
        )
        num_samples = 5
        fanout = 1
        subgraph = fused_csc_sampling_graph.sample_neighbors(
            torch.arange(
                0,
                num_samples,
                dtype=fused_csc_sampling_graph.indices.dtype,
            ),
            torch.tensor([fanout]),
        )
        assert len(subgraph.sampled_csc.indices) <= num_samples
@pytest.mark.parametrize("auto_cast", [False, True])
def test_OnDiskDataset_preprocess_heterogeneous_hardcode(
    auto_cast, edge_fmt="numpy"
):
    """Test preprocess of OnDiskDataset."""
    with tempfile.TemporaryDirectory() as test_dir:
        """Original graph in COO:
@@ -1507,6 +1530,7 @@ def test_OnDiskDataset_preprocess_heterogeneous_hardcode(edge_fmt="numpy"):
        output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(
            test_dir,
            include_original_edge_id=True,
            auto_cast_to_optimal_dtype=auto_cast,
        )
        with open(output_file, "rb") as f:
@@ -1548,6 +1572,18 @@ def test_OnDiskDataset_preprocess_heterogeneous_hardcode(edge_fmt="numpy"):
            fused_csc_sampling_graph.edge_attributes[gb.ORIGINAL_EDGE_ID],
            torch.tensor([0, 1, 0, 2, 0, 1, 2, 0, 1, 2]),
        )
        expected_dtype = torch.int32 if auto_cast else torch.int64
        assert fused_csc_sampling_graph.csc_indptr.dtype == expected_dtype
        assert fused_csc_sampling_graph.indices.dtype == expected_dtype
        assert (
            fused_csc_sampling_graph.edge_attributes[gb.ORIGINAL_EDGE_ID].dtype
            == expected_dtype
        )
        assert fused_csc_sampling_graph.node_type_offset.dtype == expected_dtype
        expected_etype_dtype = torch.uint8 if auto_cast else torch.int64
        assert (
            fused_csc_sampling_graph.type_per_edge.dtype == expected_etype_dtype
        )
def test_OnDiskDataset_preprocess_path():
@@ -2622,9 +2658,12 @@ def test_BuiltinDataset():
        _ = gb.BuiltinDataset(name=dataset_name, root=test_dir).load()
@pytest.mark.parametrize("auto_cast", [True, False])
@pytest.mark.parametrize("include_original_edge_id", [True, False]) @pytest.mark.parametrize("include_original_edge_id", [True, False])
@pytest.mark.parametrize("edge_fmt", ["csv", "numpy"]) @pytest.mark.parametrize("edge_fmt", ["csv", "numpy"])
def test_OnDiskDataset_homogeneous(include_original_edge_id, edge_fmt): def test_OnDiskDataset_homogeneous(
auto_cast, include_original_edge_id, edge_fmt
):
"""Preprocess and instantiate OnDiskDataset for homogeneous graph.""" """Preprocess and instantiate OnDiskDataset for homogeneous graph."""
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified. # All metadata fields are specified.
...@@ -2647,7 +2686,9 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id, edge_fmt): ...@@ -2647,7 +2686,9 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id, edge_fmt):
            f.write(yaml_content)

        dataset = gb.OnDiskDataset(
            test_dir,
            include_original_edge_id=include_original_edge_id,
            auto_cast_to_optimal_dtype=auto_cast,
        ).load()
        assert dataset.dataset_name == dataset_name
@@ -2673,6 +2714,10 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id, edge_fmt):
        assert isinstance(tasks[0].train_set, gb.ItemSet)
        assert isinstance(tasks[0].validation_set, gb.ItemSet)
        assert isinstance(tasks[0].test_set, gb.ItemSet)
        assert tasks[0].train_set._items[0].dtype == graph.indices.dtype
        assert tasks[0].validation_set._items[0].dtype == graph.indices.dtype
        assert tasks[0].test_set._items[0].dtype == graph.indices.dtype
        assert dataset.all_nodes_set._items.dtype == graph.indices.dtype
        assert tasks[0].metadata["num_classes"] == num_classes
        assert tasks[0].metadata["name"] == "link_prediction"
@@ -2683,6 +2728,7 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id, edge_fmt):
            tasks[0].train_set,
            tasks[0].validation_set,
            tasks[0].test_set,
            dataset.all_nodes_set,
        ]:
            datapipe = gb.ItemSampler(itemset, batch_size=10)
            datapipe = datapipe.sample_neighbor(graph, [-1])
@@ -2698,9 +2744,12 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id, edge_fmt):
        dataset = None
@pytest.mark.parametrize("auto_cast", [True, False])
@pytest.mark.parametrize("include_original_edge_id", [True, False]) @pytest.mark.parametrize("include_original_edge_id", [True, False])
@pytest.mark.parametrize("edge_fmt", ["csv", "numpy"]) @pytest.mark.parametrize("edge_fmt", ["csv", "numpy"])
def test_OnDiskDataset_heterogeneous(include_original_edge_id, edge_fmt): def test_OnDiskDataset_heterogeneous(
auto_cast, include_original_edge_id, edge_fmt
):
"""Preprocess and instantiate OnDiskDataset for heterogeneous graph.""" """Preprocess and instantiate OnDiskDataset for heterogeneous graph."""
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
dataset_name = "OnDiskDataset_hetero" dataset_name = "OnDiskDataset_hetero"
...@@ -2723,7 +2772,9 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id, edge_fmt): ...@@ -2723,7 +2772,9 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id, edge_fmt):
        )
        dataset = gb.OnDiskDataset(
            test_dir,
            include_original_edge_id=include_original_edge_id,
            auto_cast_to_optimal_dtype=auto_cast,
        ).load()
        assert dataset.dataset_name == dataset_name
@@ -2736,6 +2787,8 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id, edge_fmt):
        assert graph.total_num_edges == sum(
            num_edge for num_edge in num_edges.values()
        )
        expected_dtype = torch.int32 if auto_cast else torch.int64
        assert graph.indices.dtype == expected_dtype
        assert (
            graph.node_attributes is not None
            and "feat" in graph.node_attributes
@@ -2763,6 +2816,7 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id, edge_fmt):
            tasks[0].train_set,
            tasks[0].validation_set,
            tasks[0].test_set,
            dataset.all_nodes_set,
        ]:
            datapipe = gb.ItemSampler(itemset, batch_size=10)
            datapipe = datapipe.sample_neighbor(graph, [-1])
...