Unverified Commit 0bd79a2b authored by Xinyu Yao's avatar Xinyu Yao Committed by GitHub
Browse files

[GraphBolt] Update name used in `LegacyDataset`. (#7283)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent d9caeaaa
......@@ -33,7 +33,7 @@ class LegacyDataset(Dataset):
for key in idx.keys():
item_set = ItemSet(
(idx[key], labels[key][idx[key]]),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
)
item_set_dict[key] = item_set
return ItemSetDict(item_set_dict)
......@@ -58,7 +58,7 @@ class LegacyDataset(Dataset):
item_set_dict = {}
for ntype in graph.ntypes:
item_set = ItemSet(graph.num_nodes(ntype), names="seed_nodes")
item_set = ItemSet(graph.num_nodes(ntype), names="seeds")
item_set_dict[ntype] = item_set
self._all_nodes_set = ItemSetDict(item_set_dict)
......@@ -100,21 +100,21 @@ class LegacyDataset(Dataset):
test_labels = legacy[0].ndata["label"][legacy.test_idx]
train_set = ItemSet(
(legacy.train_idx, train_labels),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
)
validation_set = ItemSet(
(legacy.val_idx, validation_labels),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
)
test_set = ItemSet(
(legacy.test_idx, test_labels), names=("seed_nodes", "labels")
(legacy.test_idx, test_labels), names=("seeds", "labels")
)
task = OnDiskTask(metadata, train_set, validation_set, test_set)
tasks.append(task)
self._tasks = tasks
num_nodes = legacy[0].num_nodes()
self._all_nodes_set = ItemSet(num_nodes, names="seed_nodes")
self._all_nodes_set = ItemSet(num_nodes, names="seeds")
features = {}
for name in legacy[0].ndata.keys():
tensor = legacy[0].ndata[name]
......
......@@ -12,11 +12,11 @@ def test_LegacyDataset_homo_node_pred():
# Check tasks.
assert len(dataset.tasks) == 1
task = dataset.tasks[0]
assert task.train_set.names == ("seed_nodes", "labels")
assert task.train_set.names == ("seeds", "labels")
assert len(task.train_set) == 140
assert task.validation_set.names == ("seed_nodes", "labels")
assert task.validation_set.names == ("seeds", "labels")
assert len(task.validation_set) == 500
assert task.test_set.names == ("seed_nodes", "labels")
assert task.test_set.names == ("seeds", "labels")
assert len(task.test_set) == 1000
assert task.metadata["num_classes"] == 7
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment