Unverified commit e5199ac2, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Add `__repr__` to `OnDiskTask`. (#6804)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 77bf0ac6
......@@ -14,7 +14,7 @@ from ...base import dgl_warning
from ...data.utils import download, extract_archive
from ..base import etype_str_to_tuple
from ..dataset import Dataset, Task
from ..internal import copy_or_convert_data, read_data
from ..internal import copy_or_convert_data, get_attributes, read_data
from ..itemset import ItemSet, ItemSetDict
from ..sampling_graph import SamplingGraph
from .fused_csc_sampling_graph import from_dglgraph, FusedCSCSamplingGraph
......@@ -270,6 +270,9 @@ class OnDiskTask:
"""Return the test set."""
return self._test_set
def __repr__(self) -> str:
    """Return the text that ``_ondisk_task_str`` produces for this task."""
    rendered = _ondisk_task_str(self)
    return rendered
class OnDiskDataset(Dataset):
"""An on-disk dataset which reads graph topology, feature data and
......@@ -609,3 +612,25 @@ class BuiltinDataset(OnDiskDataset):
extract_archive(zip_file_path, root, overwrite=True)
os.remove(zip_file_path)
super().__init__(dataset_dir)
def _ondisk_task_str(task: OnDiskTask) -> str:
    """Render ``task`` as a multi-line ``OnDiskTask(name=value, ...)`` string.

    Each name returned by ``get_attributes(task)`` that does not start with
    an underscore is emitted as ``name=value,`` on its own line, walking the
    returned list in reverse order. Continuation lines of a multi-line value
    are indented so they line up under the first character of that value.

    Parameters
    ----------
    task : OnDiskTask
        The task to describe.

    Returns
    -------
    str
        The formatted description; ``"OnDiskTask()"`` when no attribute
        passes the underscore filter.
    """
    prefix = "OnDiskTask("
    indent_len = len(prefix)

    def _add_indent(_str, indent):
        # Shift every line after the first so a nested multi-line value
        # stays aligned with the column where its first line begins.
        first, *rest = _str.split("\n")
        return "\n".join([first] + [" " * indent + line for line in rest])

    entries = []
    # ``reversed`` walks the list back-to-front without mutating the list
    # object handed back by ``get_attributes``.
    for name in reversed(get_attributes(task)):
        if name.startswith("_"):
            continue
        # +1 accounts for the "=" between the name and its value.
        value = _add_indent(
            str(getattr(task, name)), indent_len + len(name) + 1
        )
        entries.append(f"{name}={value},\n")

    # Joining with the indent places padding only *between* entries, so the
    # closing parenthesis lands at column 0 and an empty entry list yields
    # "OnDiskTask()" rather than slicing the prefix away.
    return prefix + (" " * indent_len).join(entries) + ")"
......@@ -2163,3 +2163,63 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id):
graph = None
tasks = None
dataset = None
def test_OnDiskTask_repr_homogeneous():
# Checks the string form of an OnDiskTask whose item sets are plain ItemSets.
# One ItemSet pairing tensors 0..4 ("seed_nodes") with 5..9 ("labels") is
# built once and passed as all three item-set arguments of OnDiskTask.
item_set = gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"),
)
metadata = {"name": "node_classification"}
task = gb.OnDiskTask(metadata, item_set, item_set, item_set)
# Expected text lists validation_set, train_set, test_set, then metadata.
# NOTE(review): the continuation lines of this literal carry no leading
# spaces here, while _ondisk_task_str pads continuation lines with spaces;
# the literal's whitespace looks lost in extraction -- confirm against the
# repository before relying on it.
expected_str = str(
"""OnDiskTask(validation_set=ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),
names=('seed_nodes', 'labels'),
),
train_set=ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),
names=('seed_nodes', 'labels'),
),
test_set=ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),
names=('seed_nodes', 'labels'),
),
metadata={'name': 'node_classification'},
)"""
)
# print(task) is evaluated only when the comparison fails: it dumps the
# actual text to stdout, and the assertion message itself is None.
assert str(task) == expected_str, print(task)
def test_OnDiskTask_repr_heterogeneous():
# Checks the string form of an OnDiskTask whose item sets are ItemSetDicts.
# The dict maps "user" to an ItemSet over 0..4 and "item" to an ItemSet over
# 5..9, both named "seed_nodes"; the same ItemSetDict is passed as all three
# item-set arguments of OnDiskTask.
item_set = gb.ItemSetDict(
{
"user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
"item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
}
)
metadata = {"name": "node_classification"}
task = gb.OnDiskTask(metadata, item_set, item_set, item_set)
# Expected text lists validation_set, train_set, test_set, then metadata,
# with each nested ItemSet rendered inside the ItemSetDict's items mapping.
# NOTE(review): the continuation lines of this literal carry no leading
# spaces here, while _ondisk_task_str pads continuation lines with spaces;
# the literal's whitespace looks lost in extraction -- confirm against the
# repository before relying on it.
expected_str = str(
"""OnDiskTask(validation_set=ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),),
names=('seed_nodes',),
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),),
names=('seed_nodes',),
)},
names=('seed_nodes',),
),
train_set=ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),),
names=('seed_nodes',),
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),),
names=('seed_nodes',),
)},
names=('seed_nodes',),
),
test_set=ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),),
names=('seed_nodes',),
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),),
names=('seed_nodes',),
)},
names=('seed_nodes',),
),
metadata={'name': 'node_classification'},
)"""
)
# print(task) is evaluated only when the comparison fails: it dumps the
# actual text to stdout, and the assertion message itself is None.
assert str(task) == expected_str, print(task)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment