Unverified Commit 6068dc31 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] add names for each item in ItemSet (#6254)

parent bbc8ff62
...@@ -446,7 +446,8 @@ class OnDiskDataset(Dataset): ...@@ -446,7 +446,8 @@ class OnDiskDataset(Dataset):
tuple( tuple(
read_data(data.path, data.format, data.in_memory) read_data(data.path, data.format, data.in_memory)
for data in tvt_set[0].data for data in tvt_set[0].data
) ),
names=tuple(data.name for data in tvt_set[0].data),
) )
else: else:
data = {} data = {}
...@@ -455,7 +456,8 @@ class OnDiskDataset(Dataset): ...@@ -455,7 +456,8 @@ class OnDiskDataset(Dataset):
tuple( tuple(
read_data(data.path, data.format, data.in_memory) read_data(data.path, data.format, data.in_memory)
for data in tvt.data for data in tvt.data
) ),
names=tuple(data.name for data in tvt.data),
) )
ret = ItemSetDict(data) ret = ItemSetDict(data)
return ret return ret
...@@ -28,6 +28,7 @@ class OnDiskFeatureDataFormat(str, Enum): ...@@ -28,6 +28,7 @@ class OnDiskFeatureDataFormat(str, Enum):
class OnDiskTVTSetData(pydantic.BaseModel):
    """Train-Validation-Test set data.

    Schema for one data file of a TVT-set entry in ``metadata.yaml``:
    an optional per-item name plus the storage format, the in-memory
    policy and the file path.
    """

    # Optional semantic name of this item (e.g. "seed_node", "label");
    # None when the YAML entry omits `name`.
    name: Optional[str] = None
    # On-disk storage format of the data file.
    format: OnDiskFeatureDataFormat
    # Whether to load the file fully into memory; defaults to True.
    in_memory: Optional[bool] = True
    # Path to the data file.
    path: str
......
...@@ -47,11 +47,26 @@ class ItemSet: ...@@ -47,11 +47,26 @@ class ItemSet:
(tensor(4), tensor(9), tensor([18, 19]))] (tensor(4), tensor(9), tensor([18, 19]))]
""" """
def __init__(self, items: Iterable or Tuple[Iterable]) -> None: def __init__(
self,
items: Iterable or Tuple[Iterable],
names: str or Tuple[str] = None,
) -> None:
if isinstance(items, tuple): if isinstance(items, tuple):
self._items = items self._items = items
else: else:
self._items = (items,) self._items = (items,)
if names is not None:
if isinstance(names, tuple):
self._names = names
else:
self._names = (names,)
assert len(self._items) == len(self._names), (
f"Number of items ({len(self._items)}) and "
f"names ({len(self._names)}) must match."
)
else:
self._names = None
def __iter__(self) -> Iterator: def __iter__(self) -> Iterator:
if len(self._items) == 1: if len(self._items) == 1:
...@@ -68,6 +83,11 @@ class ItemSet: ...@@ -68,6 +83,11 @@ class ItemSet:
f"{type(self).__name__} instance doesn't have valid length." f"{type(self).__name__} instance doesn't have valid length."
) )
    @property
    def names(self) -> Tuple[str]:
        """Return the names of the items.

        Returns ``None`` when this ItemSet was constructed without names.
        """
        return self._names
class ItemSetDict: class ItemSetDict:
r"""An iterable ItemsetDict. r"""An iterable ItemsetDict.
...@@ -127,6 +147,10 @@ class ItemSetDict: ...@@ -127,6 +147,10 @@ class ItemSetDict:
def __init__(self, itemsets: Dict[str, ItemSet]) -> None: def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
self._itemsets = itemsets self._itemsets = itemsets
self._names = itemsets[list(itemsets.keys())[0]].names
assert all(
self._names == itemset.names for itemset in itemsets.values()
), "All itemsets must have the same names."
def __iter__(self) -> Iterator: def __iter__(self) -> Iterator:
for key, itemset in self._itemsets.items(): for key, itemset in self._itemsets.items():
...@@ -135,3 +159,8 @@ class ItemSetDict: ...@@ -135,3 +159,8 @@ class ItemSetDict:
def __len__(self) -> int: def __len__(self) -> int:
return sum(len(itemset) for itemset in self._itemsets.values()) return sum(len(itemset) for itemset in self._itemsets.values())
    @property
    def names(self) -> Tuple[str]:
        """Return the names of the items.

        All member ItemSets share these names; ``None`` when the members
        were constructed without names.
        """
        return self._names
...@@ -61,6 +61,103 @@ def test_OnDiskDataset_TVTSet_exceptions(): ...@@ -61,6 +61,103 @@ def test_OnDiskDataset_TVTSet_exceptions():
_ = gb.OnDiskDataset(test_dir).load() _ = gb.OnDiskDataset(test_dir).load()
def test_OnDiskDataset_TVTSet_ItemSet_names():
    """Test TVTSet which returns ItemSet with IDs, labels and corresponding names."""
    with tempfile.TemporaryDirectory() as test_dir:
        # Persist random IDs and labels as numpy files for the YAML to
        # reference.
        train_ids = np.arange(1000)
        train_ids_path = os.path.join(test_dir, "train_ids.npy")
        np.save(train_ids_path, train_ids)
        train_labels = np.random.randint(0, 10, size=1000)
        train_labels_path = os.path.join(test_dir, "train_labels.npy")
        np.save(train_labels_path, train_labels)

        # The third data entry deliberately omits `name` to verify that
        # unnamed items surface as None in ItemSet.names.
        yaml_content = f"""
            tasks:
              - name: node_classification
                num_classes: 10
                train_set:
                  - type: null
                    data:
                      - name: seed_node
                        format: numpy
                        in_memory: true
                        path: {train_ids_path}
                      - name: label
                        format: numpy
                        in_memory: true
                        path: {train_labels_path}
                      - format: numpy
                        in_memory: true
                        path: {train_labels_path}
        """
        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
        with open(yaml_file, "w") as f:
            f.write(yaml_content)

        dataset = gb.OnDiskDataset(test_dir).load()

        # Verify train set contents and the propagated names.
        train_set = dataset.tasks[0].train_set
        assert len(train_set) == 1000
        assert isinstance(train_set, gb.ItemSet)
        for i, (id, label, _) in enumerate(train_set):
            assert id == train_ids[i]
            assert label == train_labels[i]
        # Named YAML entries keep their `name`; the unnamed one is None.
        assert train_set.names == ("seed_node", "label", None)
        train_set = None
def test_OnDiskDataset_TVTSet_ItemSetDict_names():
    """Test TVTSet which returns ItemSetDict with IDs, labels and corresponding names."""
    with tempfile.TemporaryDirectory() as test_dir:
        # Persist random IDs and labels as numpy files for the YAML to
        # reference.
        train_ids = np.arange(1000)
        train_ids_path = os.path.join(test_dir, "train_ids.npy")
        np.save(train_ids_path, train_ids)
        train_labels = np.random.randint(0, 10, size=1000)
        train_labels_path = os.path.join(test_dir, "train_labels.npy")
        np.save(train_labels_path, train_labels)

        # A non-null `type` makes the loader build an ItemSetDict keyed
        # by the edge type; the third entry omits `name` on purpose.
        yaml_content = f"""
            tasks:
              - name: node_classification
                num_classes: 10
                train_set:
                  - type: "author:writes:paper"
                    data:
                      - name: seed_node
                        format: numpy
                        in_memory: true
                        path: {train_ids_path}
                      - name: label
                        format: numpy
                        in_memory: true
                        path: {train_labels_path}
                      - format: numpy
                        in_memory: true
                        path: {train_labels_path}
        """
        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
        with open(yaml_file, "w") as f:
            f.write(yaml_content)

        dataset = gb.OnDiskDataset(test_dir).load()

        # Verify train set: each item is a dict keyed by edge type.
        train_set = dataset.tasks[0].train_set
        assert len(train_set) == 1000
        assert isinstance(train_set, gb.ItemSetDict)
        for i, item in enumerate(train_set):
            assert isinstance(item, dict)
            assert "author:writes:paper" in item
            id, label, _ = item["author:writes:paper"]
            assert id == train_ids[i]
            assert label == train_labels[i]
        # Named YAML entries keep their `name`; the unnamed one is None.
        assert train_set.names == ("seed_node", "label", None)
        train_set = None
def test_OnDiskDataset_TVTSet_ItemSet_id_label(): def test_OnDiskDataset_TVTSet_ItemSet_id_label():
"""Test TVTSet which returns ItemSet with IDs and labels.""" """Test TVTSet which returns ItemSet with IDs and labels."""
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
...@@ -96,27 +193,33 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -96,27 +193,33 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
train_set: train_set:
- type: null - type: null
data: data:
- format: numpy - name: seed_node
format: numpy
in_memory: true in_memory: true
path: {train_ids_path} path: {train_ids_path}
- format: numpy - name: label
format: numpy
in_memory: true in_memory: true
path: {train_labels_path} path: {train_labels_path}
validation_set: validation_set:
- data: - data:
- format: numpy - name: seed_node
format: numpy
in_memory: true in_memory: true
path: {validation_ids_path} path: {validation_ids_path}
- format: numpy - name: label
format: numpy
in_memory: true in_memory: true
path: {validation_labels_path} path: {validation_labels_path}
test_set: test_set:
- type: null - type: null
data: data:
- format: numpy - name: seed_node
format: numpy
in_memory: true in_memory: true
path: {test_ids_path} path: {test_ids_path}
- format: numpy - name: label
format: numpy
in_memory: true in_memory: true
path: {test_labels_path} path: {test_labels_path}
""" """
...@@ -139,6 +242,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -139,6 +242,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for i, (id, label) in enumerate(train_set): for i, (id, label) in enumerate(train_set):
assert id == train_ids[i] assert id == train_ids[i]
assert label == train_labels[i] assert label == train_labels[i]
assert train_set.names == ("seed_node", "label")
train_set = None train_set = None
# Verify validation set. # Verify validation set.
...@@ -148,6 +252,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -148,6 +252,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for i, (id, label) in enumerate(validation_set): for i, (id, label) in enumerate(validation_set):
assert id == validation_ids[i] assert id == validation_ids[i]
assert label == validation_labels[i] assert label == validation_labels[i]
assert validation_set.names == ("seed_node", "label")
validation_set = None validation_set = None
# Verify test set. # Verify test set.
...@@ -157,6 +262,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -157,6 +262,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for i, (id, label) in enumerate(test_set): for i, (id, label) in enumerate(test_set):
assert id == test_ids[i] assert id == test_ids[i]
assert label == test_labels[i] assert label == test_labels[i]
assert test_set.names == ("seed_node", "label")
test_set = None test_set = None
dataset = None dataset = None
...@@ -220,36 +326,45 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -220,36 +326,45 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
train_set: train_set:
- type: null - type: null
data: data:
- format: numpy - name: src
format: numpy
in_memory: true in_memory: true
path: {train_src_path} path: {train_src_path}
- format: numpy - name: dst
format: numpy
in_memory: true in_memory: true
path: {train_dst_path} path: {train_dst_path}
- format: numpy - name: label
format: numpy
in_memory: true in_memory: true
path: {train_labels_path} path: {train_labels_path}
validation_set: validation_set:
- data: - data:
- format: numpy - name: src
format: numpy
in_memory: true in_memory: true
path: {validation_src_path} path: {validation_src_path}
- format: numpy - name: dst
format: numpy
in_memory: true in_memory: true
path: {validation_dst_path} path: {validation_dst_path}
- format: numpy - name: label
format: numpy
in_memory: true in_memory: true
path: {validation_labels_path} path: {validation_labels_path}
test_set: test_set:
- type: null - type: null
data: data:
- format: numpy - name: src
format: numpy
in_memory: true in_memory: true
path: {test_src_path} path: {test_src_path}
- format: numpy - name: dst
format: numpy
in_memory: true in_memory: true
path: {test_dst_path} path: {test_dst_path}
- format: numpy - name: label
format: numpy
in_memory: true in_memory: true
path: {test_labels_path} path: {test_labels_path}
""" """
...@@ -268,6 +383,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -268,6 +383,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert src == train_src[i] assert src == train_src[i]
assert dst == train_dst[i] assert dst == train_dst[i]
assert label == train_labels[i] assert label == train_labels[i]
assert train_set.names == ("src", "dst", "label")
train_set = None train_set = None
# Verify validation set. # Verify validation set.
...@@ -278,6 +394,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -278,6 +394,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert src == validation_src[i] assert src == validation_src[i]
assert dst == validation_dst[i] assert dst == validation_dst[i]
assert label == validation_labels[i] assert label == validation_labels[i]
assert validation_set.names == ("src", "dst", "label")
validation_set = None validation_set = None
# Verify test set. # Verify test set.
...@@ -288,6 +405,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -288,6 +405,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert src == test_src[i] assert src == test_src[i]
assert dst == test_dst[i] assert dst == test_dst[i]
assert label == test_labels[i] assert label == test_labels[i]
assert test_set.names == ("src", "dst", "label")
test_set = None test_set = None
dataset = None dataset = None
...@@ -335,36 +453,45 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs(): ...@@ -335,36 +453,45 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
train_set: train_set:
- type: null - type: null
data: data:
- format: numpy - name: src
format: numpy
in_memory: true in_memory: true
path: {train_src_path} path: {train_src_path}
- format: numpy - name: dst
format: numpy
in_memory: true in_memory: true
path: {train_dst_path} path: {train_dst_path}
- format: numpy - name: negative_dst
format: numpy
in_memory: true in_memory: true
path: {train_neg_dst_path} path: {train_neg_dst_path}
validation_set: validation_set:
- data: - data:
- format: numpy - name: src
format: numpy
in_memory: true in_memory: true
path: {validation_src_path} path: {validation_src_path}
- format: numpy - name: dst
format: numpy
in_memory: true in_memory: true
path: {validation_dst_path} path: {validation_dst_path}
- format: numpy - name: negative_dst
format: numpy
in_memory: true in_memory: true
path: {validation_neg_dst_path} path: {validation_neg_dst_path}
test_set: test_set:
- type: null - type: null
data: data:
- format: numpy - name: src
format: numpy
in_memory: true in_memory: true
path: {test_src_path} path: {test_src_path}
- format: numpy - name: dst
format: numpy
in_memory: true in_memory: true
path: {test_dst_path} path: {test_dst_path}
- format: numpy - name: negative_dst
format: numpy
in_memory: true in_memory: true
path: {test_neg_dst_path} path: {test_neg_dst_path}
""" """
...@@ -383,6 +510,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs(): ...@@ -383,6 +510,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
assert src == train_src[i] assert src == train_src[i]
assert dst == train_dst[i] assert dst == train_dst[i]
assert torch.equal(negs, torch.from_numpy(train_neg_dst[i])) assert torch.equal(negs, torch.from_numpy(train_neg_dst[i]))
assert train_set.names == ("src", "dst", "negative_dst")
train_set = None train_set = None
# Verify validation set. # Verify validation set.
...@@ -393,6 +521,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs(): ...@@ -393,6 +521,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
assert src == validation_src[i] assert src == validation_src[i]
assert dst == validation_dst[i] assert dst == validation_dst[i]
assert torch.equal(negs, torch.from_numpy(validation_neg_dst[i])) assert torch.equal(negs, torch.from_numpy(validation_neg_dst[i]))
assert validation_set.names == ("src", "dst", "negative_dst")
validation_set = None validation_set = None
# Verify test set. # Verify test set.
...@@ -403,6 +532,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs(): ...@@ -403,6 +532,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
assert src == test_src[i] assert src == test_src[i]
assert dst == test_dst[i] assert dst == test_dst[i]
assert torch.equal(negs, torch.from_numpy(test_neg_dst[i])) assert torch.equal(negs, torch.from_numpy(test_neg_dst[i]))
assert test_set.names == ("src", "dst", "negative_dst")
test_set = None test_set = None
dataset = None dataset = None
...@@ -434,31 +564,37 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): ...@@ -434,31 +564,37 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
train_set: train_set:
- type: paper - type: paper
data: data:
- format: numpy - name: seed_node
format: numpy
in_memory: true in_memory: true
path: {train_path} path: {train_path}
- type: author - type: author
data: data:
- format: numpy - name: seed_node
format: numpy
path: {train_path} path: {train_path}
validation_set: validation_set:
- type: paper - type: paper
data: data:
- format: numpy - name: seed_node
format: numpy
path: {validation_path} path: {validation_path}
- type: author - type: author
data: data:
- format: numpy - name: seed_node
format: numpy
path: {validation_path} path: {validation_path}
test_set: test_set:
- type: paper - type: paper
data: data:
- format: numpy - name: seed_node
format: numpy
in_memory: false in_memory: false
path: {test_path} path: {test_path}
- type: author - type: author
data: data:
- format: numpy - name: seed_node
format: numpy
path: {test_path} path: {test_path}
""" """
os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True) os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
...@@ -480,6 +616,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): ...@@ -480,6 +616,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id, label = item[key] id, label = item[key]
assert id == train_ids[i % 1000] assert id == train_ids[i % 1000]
assert label == train_labels[i % 1000] assert label == train_labels[i % 1000]
assert train_set.names == ("seed_node",)
train_set = None train_set = None
# Verify validation set. # Verify validation set.
...@@ -494,6 +631,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): ...@@ -494,6 +631,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id, label = item[key] id, label = item[key]
assert id == validation_ids[i % 1000] assert id == validation_ids[i % 1000]
assert label == validation_labels[i % 1000] assert label == validation_labels[i % 1000]
assert validation_set.names == ("seed_node",)
validation_set = None validation_set = None
# Verify test set. # Verify test set.
...@@ -508,6 +646,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): ...@@ -508,6 +646,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id, label = item[key] id, label = item[key]
assert id == test_ids[i % 1000] assert id == test_ids[i % 1000]
assert label == test_labels[i % 1000] assert label == test_labels[i % 1000]
assert test_set.names == ("seed_node",)
test_set = None test_set = None
dataset = None dataset = None
...@@ -539,31 +678,37 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label(): ...@@ -539,31 +678,37 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
train_set: train_set:
- type: paper - type: paper
data: data:
- format: numpy - name: node_pair
format: numpy
in_memory: true in_memory: true
path: {train_path} path: {train_path}
- type: author - type: author
data: data:
- format: numpy - name: node_pair
format: numpy
path: {train_path} path: {train_path}
validation_set: validation_set:
- type: paper - type: paper
data: data:
- format: numpy - name: node_pair
format: numpy
path: {validation_path} path: {validation_path}
- type: author - type: author
data: data:
- format: numpy - name: node_pair
format: numpy
path: {validation_path} path: {validation_path}
test_set: test_set:
- type: paper - type: paper
data: data:
- format: numpy - name: node_pair
format: numpy
in_memory: false in_memory: false
path: {test_path} path: {test_path}
- type: author - type: author
data: data:
- format: numpy - name: node_pair
format: numpy
path: {test_path} path: {test_path}
""" """
os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True) os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
...@@ -586,6 +731,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label(): ...@@ -586,6 +731,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
assert src == train_pairs[0][i % 1000] assert src == train_pairs[0][i % 1000]
assert dst == train_pairs[1][i % 1000] assert dst == train_pairs[1][i % 1000]
assert label == train_labels[i % 1000] assert label == train_labels[i % 1000]
assert train_set.names == ("node_pair",)
train_set = None train_set = None
# Verify validation set. # Verify validation set.
...@@ -601,6 +747,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label(): ...@@ -601,6 +747,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
assert src == validation_pairs[0][i % 1000] assert src == validation_pairs[0][i % 1000]
assert dst == validation_pairs[1][i % 1000] assert dst == validation_pairs[1][i % 1000]
assert label == validation_labels[i % 1000] assert label == validation_labels[i % 1000]
assert validation_set.names == ("node_pair",)
validation_set = None validation_set = None
# Verify test set. # Verify test set.
...@@ -616,6 +763,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label(): ...@@ -616,6 +763,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
assert src == test_pairs[0][i % 1000] assert src == test_pairs[0][i % 1000]
assert dst == test_pairs[1][i % 1000] assert dst == test_pairs[1][i % 1000]
assert label == test_labels[i % 1000] assert label == test_labels[i % 1000]
assert test_set.names == ("node_pair",)
test_set = None test_set = None
dataset = None dataset = None
......
import re
import dgl import dgl
import pytest import pytest
import torch import torch
...@@ -5,6 +7,81 @@ from dgl import graphbolt as gb ...@@ -5,6 +7,81 @@ from dgl import graphbolt as gb
from torch.testing import assert_close from torch.testing import assert_close
def test_ItemSet_names():
    """Verify that ItemSet exposes the names it was constructed with."""
    # A single string name is normalized to a 1-tuple.
    single = gb.ItemSet(torch.arange(0, 5), names="seed_node")
    assert single.names == ("seed_node",)

    # A tuple of names is kept as-is, one name per iterable.
    paired = gb.ItemSet(
        (torch.arange(0, 5), torch.arange(5, 10)),
        names=("seed_node", "label"),
    )
    assert paired.names == ("seed_node", "label")

    # Omitting names leaves the property as None.
    anonymous = gb.ItemSet(torch.arange(0, 5))
    assert anonymous.names is None

    # A count mismatch between items and names is rejected up front.
    expected_msg = re.escape("Number of items (1) and names (2) must match.")
    with pytest.raises(AssertionError, match=expected_msg):
        _ = gb.ItemSet(torch.arange(0, 5), names=("seed_node", "label"))
def test_ItemSetDict_names():
    """Verify that ItemSetDict propagates names from its member ItemSets."""
    # Every member carries the same single name.
    single = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5), names="seed_node"),
            "item": gb.ItemSet(torch.arange(5, 10), names="seed_node"),
        }
    )
    assert single.names == ("seed_node",)

    # Every member carries the same pair of names.
    paired = gb.ItemSetDict(
        {
            "user": gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)),
                names=("seed_node", "label"),
            ),
            "item": gb.ItemSet(
                (torch.arange(5, 10), torch.arange(10, 15)),
                names=("seed_node", "label"),
            ),
        }
    )
    assert paired.names == ("seed_node", "label")

    # Members constructed without names yield a None property.
    anonymous = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5)),
            "item": gb.ItemSet(torch.arange(5, 10)),
        }
    )
    assert anonymous.names is None

    # Members whose names disagree are rejected at construction time.
    with pytest.raises(
        AssertionError,
        match=re.escape("All itemsets must have the same names."),
    ):
        _ = gb.ItemSetDict(
            {
                "user": gb.ItemSet(
                    (torch.arange(0, 5), torch.arange(5, 10)),
                    names=("seed_node", "label"),
                ),
                "item": gb.ItemSet(
                    (torch.arange(5, 10),), names=("seed_node",)
                ),
            }
        )
def test_ItemSet_valid_length(): def test_ItemSet_valid_length():
# Single iterable. # Single iterable.
ids = torch.arange(0, 5) ids = torch.arange(0, 5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment