Unverified Commit ea58090e authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Add names to `ItemSets` in `_init_all_nodes_set` (#6461)

parent 4fca8817
...@@ -11,6 +11,7 @@ import yaml ...@@ -11,6 +11,7 @@ import yaml
import dgl import dgl
from ...base import dgl_warning
from ...data.utils import download, extract_archive from ...data.utils import download, extract_archive
from ..base import etype_str_to_tuple from ..base import etype_str_to_tuple
from ..dataset import Dataset, Task from ..dataset import Dataset, Task
...@@ -498,13 +499,16 @@ class OnDiskDataset(Dataset): ...@@ -498,13 +499,16 @@ class OnDiskDataset(Dataset):
def _init_all_nodes_set(self, graph) -> Union[ItemSet, ItemSetDict]: def _init_all_nodes_set(self, graph) -> Union[ItemSet, ItemSetDict]:
if graph is None: if graph is None:
dgl_warning(
"`all_node_set` is returned as None, since graph is None."
)
return None return None
num_nodes = graph.num_nodes num_nodes = graph.num_nodes
if isinstance(num_nodes, int): if isinstance(num_nodes, int):
return ItemSet(num_nodes) return ItemSet(num_nodes, names="seed_nodes")
else: else:
data = { data = {
node_type: ItemSet(num_node) node_type: ItemSet(num_node, names="seed_nodes")
for node_type, num_node in num_nodes.items() for node_type, num_node in num_nodes.items()
} }
return ItemSetDict(data) return ItemSetDict(data)
......
...@@ -1806,6 +1806,7 @@ def test_OnDiskDataset_all_nodes_set_homo(): ...@@ -1806,6 +1806,7 @@ def test_OnDiskDataset_all_nodes_set_homo():
dataset = gb.OnDiskDataset(test_dir).load() dataset = gb.OnDiskDataset(test_dir).load()
all_nodes_set = dataset.all_nodes_set all_nodes_set = dataset.all_nodes_set
assert isinstance(all_nodes_set, gb.ItemSet) assert isinstance(all_nodes_set, gb.ItemSet)
assert all_nodes_set.names == ("seed_nodes",)
for i, item in enumerate(all_nodes_set): for i, item in enumerate(all_nodes_set):
assert i == item assert i == item
...@@ -1842,6 +1843,7 @@ def test_OnDiskDataset_all_nodes_set_hetero(): ...@@ -1842,6 +1843,7 @@ def test_OnDiskDataset_all_nodes_set_hetero():
dataset = gb.OnDiskDataset(test_dir).load() dataset = gb.OnDiskDataset(test_dir).load()
all_nodes_set = dataset.all_nodes_set all_nodes_set = dataset.all_nodes_set
assert isinstance(all_nodes_set, gb.ItemSetDict) assert isinstance(all_nodes_set, gb.ItemSetDict)
assert all_nodes_set.names == ("seed_nodes",)
for i, item in enumerate(all_nodes_set): for i, item in enumerate(all_nodes_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment