"examples/pytorch/vscode:/vscode.git/clone" did not exist on "ba110e50e61f19e6cce3a7bf6166d92000e2641d"
Unverified Commit ea58090e authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Add names to `ItemSets` in `_init_all_nodes_set` (#6461)

parent 4fca8817
......@@ -11,6 +11,7 @@ import yaml
import dgl
from ...base import dgl_warning
from ...data.utils import download, extract_archive
from ..base import etype_str_to_tuple
from ..dataset import Dataset, Task
......@@ -498,13 +499,16 @@ class OnDiskDataset(Dataset):
def _init_all_nodes_set(self, graph) -> Union[ItemSet, ItemSetDict]:
if graph is None:
dgl_warning(
"`all_node_set` is returned as None, since graph is None."
)
return None
num_nodes = graph.num_nodes
if isinstance(num_nodes, int):
return ItemSet(num_nodes)
return ItemSet(num_nodes, names="seed_nodes")
else:
data = {
node_type: ItemSet(num_node)
node_type: ItemSet(num_node, names="seed_nodes")
for node_type, num_node in num_nodes.items()
}
return ItemSetDict(data)
......
......@@ -1806,6 +1806,7 @@ def test_OnDiskDataset_all_nodes_set_homo():
dataset = gb.OnDiskDataset(test_dir).load()
all_nodes_set = dataset.all_nodes_set
assert isinstance(all_nodes_set, gb.ItemSet)
assert all_nodes_set.names == ("seed_nodes",)
for i, item in enumerate(all_nodes_set):
assert i == item
......@@ -1842,6 +1843,7 @@ def test_OnDiskDataset_all_nodes_set_hetero():
dataset = gb.OnDiskDataset(test_dir).load()
all_nodes_set = dataset.all_nodes_set
assert isinstance(all_nodes_set, gb.ItemSetDict)
assert all_nodes_set.names == ("seed_nodes",)
for i, item in enumerate(all_nodes_set):
assert len(item) == 1
assert isinstance(item, dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment