"...sampling/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "b8886900837e3fc73972215b7c5e9b3e127acbfc"
Unverified commit 0bd79a2b, authored by Xinyu Yao, committed by GitHub

[GraphBolt] Update name used in `LegacyDataset`. (#7283)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent d9caeaaa
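
For readers unfamiliar with GraphBolt, below is a minimal sketch of what the renamed itemset keys look like after this change. It assumes a DGL build where `dgl.graphbolt` is importable and uses made-up node IDs and labels; the constructor calls mirror the ones touched in this diff (a tuple itemset named `("seeds", "labels")` and an integer itemset named `"seeds"`), not an authoritative usage guide.

    import torch
    import dgl.graphbolt as gb

    # Hypothetical toy data standing in for a real split of a graph.
    node_ids = torch.arange(10)
    labels = torch.randint(0, 7, (10,))

    # The first name used to be "seed_nodes"; after this commit it is "seeds".
    train_set = gb.ItemSet((node_ids, labels), names=("seeds", "labels"))
    all_nodes_set = gb.ItemSet(10, names="seeds")

    assert train_set.names == ("seeds", "labels")
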
@@ -33,7 +33,7 @@ class LegacyDataset(Dataset):
         for key in idx.keys():
             item_set = ItemSet(
                 (idx[key], labels[key][idx[key]]),
-                names=("seed_nodes", "labels"),
+                names=("seeds", "labels"),
             )
             item_set_dict[key] = item_set
         return ItemSetDict(item_set_dict)
@@ -58,7 +58,7 @@ class LegacyDataset(Dataset):
         item_set_dict = {}
         for ntype in graph.ntypes:
-            item_set = ItemSet(graph.num_nodes(ntype), names="seed_nodes")
+            item_set = ItemSet(graph.num_nodes(ntype), names="seeds")
             item_set_dict[ntype] = item_set
         self._all_nodes_set = ItemSetDict(item_set_dict)
@@ -100,21 +100,21 @@ class LegacyDataset(Dataset):
             test_labels = legacy[0].ndata["label"][legacy.test_idx]
             train_set = ItemSet(
                 (legacy.train_idx, train_labels),
-                names=("seed_nodes", "labels"),
+                names=("seeds", "labels"),
             )
             validation_set = ItemSet(
                 (legacy.val_idx, validation_labels),
-                names=("seed_nodes", "labels"),
+                names=("seeds", "labels"),
             )
             test_set = ItemSet(
-                (legacy.test_idx, test_labels), names=("seed_nodes", "labels")
+                (legacy.test_idx, test_labels), names=("seeds", "labels")
             )
             task = OnDiskTask(metadata, train_set, validation_set, test_set)
             tasks.append(task)
         self._tasks = tasks
         num_nodes = legacy[0].num_nodes()
-        self._all_nodes_set = ItemSet(num_nodes, names="seed_nodes")
+        self._all_nodes_set = ItemSet(num_nodes, names="seeds")
         features = {}
         for name in legacy[0].ndata.keys():
             tensor = legacy[0].ndata[name]
...
@@ -12,11 +12,11 @@ def test_LegacyDataset_homo_node_pred():
     # Check tasks.
     assert len(dataset.tasks) == 1
     task = dataset.tasks[0]
-    assert task.train_set.names == ("seed_nodes", "labels")
+    assert task.train_set.names == ("seeds", "labels")
     assert len(task.train_set) == 140
-    assert task.validation_set.names == ("seed_nodes", "labels")
+    assert task.validation_set.names == ("seeds", "labels")
     assert len(task.validation_set) == 500
-    assert task.test_set.names == ("seed_nodes", "labels")
+    assert task.test_set.names == ("seeds", "labels")
     assert len(task.test_set) == 1000
     assert task.metadata["num_classes"] == 7
...