Unverified commit 47d37e91, authored by Rhett Ying and committed by GitHub

[GraphBolt] convert TVT from list of list to list (#6080)

parent 12ade95c
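This commit flattens the train/validation/test (TVT) containers: each split goes from a list of `ItemSet`/`ItemSetDict` objects to a single `ItemSet` (homogeneous) or `ItemSetDict` (heterogeneous), and the accessors are renamed from `train_sets`/`validation_sets`/`test_sets` to the singular forms. A minimal usage sketch of the change (the YAML path is illustrative, not from this commit):

from dgl import graphbolt as gb

dataset = gb.OnDiskDataset("metadata.yaml")  # illustrative path

# Before: each split was a list, so callers looped twice.
#     for train_set in dataset.train_sets:
#         for item in train_set: ...
# After: each split is one ItemSet or ItemSetDict.
for item in dataset.train_set:
    ...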
"""GraphBolt Dataset.""" """GraphBolt Dataset."""
from typing import Dict, List from typing import Dict
from .feature_store import FeatureStore from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict from .itemset import ItemSet, ItemSetDict
...@@ -32,18 +32,18 @@ class Dataset: ...@@ -32,18 +32,18 @@ class Dataset:
""" """
@property @property
def train_sets(self) -> List[ItemSet] or List[ItemSetDict]: def train_set(self) -> ItemSet or ItemSetDict:
"""Return the training sets.""" """Return the training set."""
raise NotImplementedError raise NotImplementedError
@property @property
def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]: def validation_set(self) -> ItemSet or ItemSetDict:
"""Return the validation sets.""" """Return the validation set."""
raise NotImplementedError raise NotImplementedError
@property @property
def test_sets(self) -> List[ItemSet] or List[ItemSetDict]: def test_set(self) -> ItemSet or ItemSetDict:
"""Return the test sets.""" """Return the test set."""
raise NotImplementedError raise NotImplementedError
@property @property
......
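For reference, a minimal sketch of a concrete subclass of the `Dataset` interface above under the new singular properties; the class name and in-memory tensors are hypothetical, and `ItemSet` is assumed to accept a tuple of tensors, as elsewhere in this diff:

import torch
from dgl import graphbolt as gb

class ToyDataset(gb.Dataset):
    """Hypothetical in-memory dataset exposing one ItemSet per split."""

    def __init__(self):
        ids, labels = torch.arange(10), torch.zeros(10)
        self._train_set = gb.ItemSet((ids, labels))
        self._validation_set = gb.ItemSet((ids, labels))
        self._test_set = gb.ItemSet((ids, labels))

    @property
    def train_set(self):
        return self._train_set

    @property
    def validation_set(self):
        return self._validation_set

    @property
    def test_set(self):
        return self._test_set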
@@ -165,45 +165,42 @@ def preprocess_ondisk_dataset(input_config_path: str) -> str:
         )
 
     # 7. Save the train/val/test split according to the output_config.
-    for set_name in ["train_sets", "validation_sets", "test_sets"]:
+    for set_name in ["train_set", "validation_set", "test_set"]:
         if set_name not in input_config:
             continue
-        for intput_set_split, output_set_split in zip(
-            input_config[set_name], output_config[set_name]
-        ):
-            for input_set_per_type, output_set_per_type in zip(
-                intput_set_split, output_set_split
-            ):
-                for input_data, output_data in zip(
-                    input_set_per_type["data"], output_set_per_type["data"]
-                ):
-                    # Always save the feature in numpy format.
-                    output_data["format"] = "numpy"
-                    output_data["path"] = str(
-                        processed_dir_prefix
-                        / input_data["path"].replace("pt", "npy")
-                    )
-                    if input_data["format"] == "numpy":
-                        # If the original format is numpy, just copy the file.
-                        os.makedirs(
-                            dataset_path / os.path.dirname(output_data["path"]),
-                            exist_ok=True,
-                        )
-                        shutil.copy(
-                            dataset_path / input_data["path"],
-                            dataset_path / output_data["path"],
-                        )
-                    else:
-                        # If the original format is not numpy, convert it to numpy.
-                        input_set = read_data(
-                            dataset_path / input_data["path"],
-                            input_data["format"],
-                        )
-                        save_data(
-                            input_set,
-                            dataset_path / output_data["path"],
-                            output_set_per_type["format"],
-                        )
+        for input_set_per_type, output_set_per_type in zip(
+            input_config[set_name], output_config[set_name]
+        ):
+            for input_data, output_data in zip(
+                input_set_per_type["data"], output_set_per_type["data"]
+            ):
+                # Always save the feature in numpy format.
+                output_data["format"] = "numpy"
+                output_data["path"] = str(
+                    processed_dir_prefix
+                    / input_data["path"].replace("pt", "npy")
+                )
+                if input_data["format"] == "numpy":
+                    # If the original format is numpy, just copy the file.
+                    os.makedirs(
+                        dataset_path / os.path.dirname(output_data["path"]),
+                        exist_ok=True,
+                    )
+                    shutil.copy(
+                        dataset_path / input_data["path"],
+                        dataset_path / output_data["path"],
+                    )
+                else:
+                    # If the original format is not numpy, convert it to numpy.
+                    input_set = read_data(
+                        dataset_path / input_data["path"],
+                        input_data["format"],
+                    )
+                    save_data(
+                        input_set,
+                        dataset_path / output_data["path"],
+                        output_set_per_type["format"],
+                    )
 
     # 8. Save the output_config.
     output_config_path = dataset_path / "output_config.yaml"
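To make step 7 concrete, a standalone sketch of the path/format rewrite applied to one hypothetical split entry (`processed_dir_prefix` is assumed to be a `pathlib.Path`, as in the surrounding function):

from pathlib import Path

processed_dir_prefix = Path("preprocessed")  # assumed prefix
input_data = {"format": "torch", "path": "set/paper-train.pt"}

# Split data is always re-saved as numpy; the path swaps the extension.
output_data = {
    "format": "numpy",
    "path": str(
        processed_dir_prefix / input_data["path"].replace("pt", "npy")
    ),
}
print(output_data["path"])  # preprocessed/set/paper-train.npy (POSIX)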
@@ -245,27 +242,27 @@ class OnDiskDataset(Dataset):
             format: numpy
             in_memory: false
             path: edge_data/author-writes-paper-feat.npy
-        train_sets:
-          - - type: paper # could be null for homogeneous graph.
-              data: # multiple data sources could be specified.
-                - format: numpy
-                  in_memory: true # If not specified, default to true.
-                  path: set/paper-train-src.npy
-                - format: numpy
-                  in_memory: false
-                  path: set/paper-train-dst.npy
-        validation_sets:
-          - - type: paper
-              data:
-                - format: numpy
-                  in_memory: true
-                  path: set/paper-validation.npy
-        test_sets:
-          - - type: paper
-              data:
-                - format: numpy
-                  in_memory: true
-                  path: set/paper-test.npy
+        train_set:
+          - type: paper # could be null for homogeneous graph.
+            data: # multiple data sources could be specified.
+              - format: numpy
+                in_memory: true # If not specified, default to true.
+                path: set/paper-train-src.npy
+              - format: numpy
+                in_memory: false
+                path: set/paper-train-dst.npy
+        validation_set:
+          - type: paper
+            data:
+              - format: numpy
+                in_memory: true
+                path: set/paper-validation.npy
+        test_set:
+          - type: paper
+            data:
+              - format: numpy
+                in_memory: true
+                path: set/paper-test.npy
 
     Parameters
     ----------
@@ -285,24 +282,24 @@ class OnDiskDataset(Dataset):
         self._num_labels = self._meta.num_labels
         self._graph = self._load_graph(self._meta.graph_topology)
         self._feature = load_feature_stores(self._meta.feature_data)
-        self._train_sets = self._init_tvt_sets(self._meta.train_sets)
-        self._validation_sets = self._init_tvt_sets(self._meta.validation_sets)
-        self._test_sets = self._init_tvt_sets(self._meta.test_sets)
+        self._train_set = self._init_tvt_set(self._meta.train_set)
+        self._validation_set = self._init_tvt_set(self._meta.validation_set)
+        self._test_set = self._init_tvt_set(self._meta.test_set)
 
     @property
-    def train_sets(self) -> List[ItemSet] or List[ItemSetDict]:
+    def train_set(self) -> ItemSet or ItemSetDict:
         """Return the training set."""
-        return self._train_sets
+        return self._train_set
 
     @property
-    def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]:
+    def validation_set(self) -> ItemSet or ItemSetDict:
         """Return the validation set."""
-        return self._validation_sets
+        return self._validation_set
 
     @property
-    def test_sets(self) -> List[ItemSet] or List[ItemSetDict]:
+    def test_set(self) -> ItemSet or ItemSetDict:
         """Return the test set."""
-        return self._test_sets
+        return self._test_set
 
     @property
     def graph(self) -> object:
@@ -341,36 +338,31 @@
                 f"Graph topology type {graph_topology.type} is not supported."
             )
 
-    def _init_tvt_sets(
-        self, tvt_sets: List[List[OnDiskTVTSet]]
-    ) -> List[ItemSet] or List[ItemSetDict]:
-        """Initialize the TVT sets."""
-        if (tvt_sets is None) or (len(tvt_sets) == 0):
-            return None
-        ret = []
-        for tvt_set in tvt_sets:
-            if (tvt_set is None) or (len(tvt_set) == 0):
-                ret.append(None)
-            if tvt_set[0].type is None:
-                assert (
-                    len(tvt_set) == 1
-                ), "Only one TVT set is allowed if type is not specified."
-                ret.append(
-                    ItemSet(
-                        tuple(
-                            read_data(data.path, data.format, data.in_memory)
-                            for data in tvt_set[0].data
-                        )
-                    )
-                )
-            else:
-                data = {}
-                for tvt in tvt_set:
-                    data[tvt.type] = ItemSet(
-                        tuple(
-                            read_data(data.path, data.format, data.in_memory)
-                            for data in tvt.data
-                        )
-                    )
-                ret.append(ItemSetDict(data))
+    def _init_tvt_set(
+        self, tvt_set: List[OnDiskTVTSet]
+    ) -> ItemSet or ItemSetDict:
+        """Initialize the TVT set."""
+        ret = None
+        if (tvt_set is None) or (len(tvt_set) == 0):
+            return ret
+        if tvt_set[0].type is None:
+            assert (
+                len(tvt_set) == 1
+            ), "Only one TVT set is allowed if type is not specified."
+            ret = ItemSet(
+                tuple(
+                    read_data(data.path, data.format, data.in_memory)
+                    for data in tvt_set[0].data
+                )
+            )
+        else:
+            data = {}
+            for tvt in tvt_set:
+                data[tvt.type] = ItemSet(
+                    tuple(
+                        read_data(data.path, data.format, data.in_memory)
+                        for data in tvt.data
+                    )
+                )
+            ret = ItemSetDict(data)
         return ret
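The net effect of `_init_tvt_set`: an untyped entry yields a single `ItemSet`, while typed entries collapse into one `ItemSetDict` keyed by type. A rough sketch of the two outcomes, with in-memory tensors standing in for `read_data`:

import torch
from dgl import graphbolt as gb

# Homogeneous (type is null): one ItemSet over (ids, labels).
homo_set = gb.ItemSet((torch.arange(5), torch.zeros(5)))

# Heterogeneous (typed entries): one ItemSetDict keyed by type.
hetero_set = gb.ItemSetDict(
    {
        "paper": gb.ItemSet((torch.arange(5), torch.zeros(5))),
        "author": gb.ItemSet((torch.arange(5), torch.ones(5))),
    }
)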
@@ -83,6 +83,6 @@ class OnDiskMetaData(pydantic.BaseModel):
     num_labels: Optional[int] = None
     graph_topology: Optional[OnDiskGraphTopology] = None
     feature_data: Optional[List[OnDiskFeatureData]] = []
-    train_sets: Optional[List[List[OnDiskTVTSet]]] = []
-    validation_sets: Optional[List[List[OnDiskTVTSet]]] = []
-    test_sets: Optional[List[List[OnDiskTVTSet]]] = []
+    train_set: Optional[List[OnDiskTVTSet]] = []
+    validation_set: Optional[List[OnDiskTVTSet]] = []
+    test_set: Optional[List[OnDiskTVTSet]] = []
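With the flattened schema, each `*_set` field is a plain list of per-type entries. A quick sketch of what the new YAML parses to (PyYAML assumed; the fragment is illustrative):

import yaml

yaml_content = """
train_set:
  - type: paper
    data:
      - format: numpy
        path: set/paper-train.npy
"""
meta = yaml.safe_load(yaml_content)
# A flat list now, matching train_set: Optional[List[OnDiskTVTSet]].
assert isinstance(meta["train_set"], list)
assert meta["train_set"][0]["type"] == "paper"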
-import os
-import tempfile
-
-import numpy as np
-import pydantic
 import pytest
 
 from dgl import graphbolt as gb
@@ -11,15 +5,15 @@ from dgl import graphbolt as gb
 def test_Dataset():
     dataset = gb.Dataset()
     with pytest.raises(NotImplementedError):
-        _ = dataset.train_sets()
+        _ = dataset.train_set
     with pytest.raises(NotImplementedError):
-        _ = dataset.validation_sets()
+        _ = dataset.validation_set
     with pytest.raises(NotImplementedError):
-        _ = dataset.test_sets()
+        _ = dataset.test_set
     with pytest.raises(NotImplementedError):
-        _ = dataset.graph()
+        _ = dataset.graph
     with pytest.raises(NotImplementedError):
-        _ = dataset.feature()
+        _ = dataset.feature
     with pytest.raises(NotImplementedError):
         _ = dataset.dataset_name
     with pytest.raises(NotImplementedError):
...
@@ -20,11 +20,11 @@ def test_OnDiskDataset_TVTSet_exceptions():
 
         # Case 1: ``format`` is invalid.
         yaml_content = """
-        train_sets:
-          - - type: paper
-              data:
-                - format: torch_invalid
-                  path: set/paper-train.pt
+        train_set:
+          - type: paper
+            data:
+              - format: torch_invalid
+                path: set/paper-train.pt
         """
         yaml_file = os.path.join(test_dir, "test.yaml")
         with open(yaml_file, "w") as f:
@@ -34,15 +34,15 @@ def test_OnDiskDataset_TVTSet_exceptions():
 
         # Case 2: ``type`` is not specified while multiple TVT sets are specified.
         yaml_content = """
-        train_sets:
-          - - type: null
-              data:
-                - format: numpy
-                  path: set/train.npy
-            - type: null
-              data:
-                - format: numpy
-                  path: set/train.npy
+        train_set:
+          - type: null
+            data:
+              - format: numpy
+                path: set/train.npy
+          - type: null
+            data:
+              - format: numpy
+                path: set/train.npy
         """
         with open(yaml_file, "w") as f:
             f.write(yaml_content)
@@ -82,32 +82,32 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
         # ``type`` is not specified or specified as ``null``.
         # ``in_memory`` could be ``true`` and ``false``.
         yaml_content = f"""
-        train_sets:
-          - - type: null
-              data:
-                - format: numpy
-                  in_memory: true
-                  path: {train_ids_path}
-                - format: numpy
-                  in_memory: true
-                  path: {train_labels_path}
-        validation_sets:
-          - - data:
-                - format: numpy
-                  in_memory: true
-                  path: {validation_ids_path}
-                - format: numpy
-                  in_memory: true
-                  path: {validation_labels_path}
-        test_sets:
-          - - type: null
-              data:
-                - format: numpy
-                  in_memory: true
-                  path: {test_ids_path}
-                - format: numpy
-                  in_memory: true
-                  path: {test_labels_path}
+        train_set:
+          - type: null
+            data:
+              - format: numpy
+                in_memory: true
+                path: {train_ids_path}
+              - format: numpy
+                in_memory: true
+                path: {train_labels_path}
+        validation_set:
+          - data:
+              - format: numpy
+                in_memory: true
+                path: {validation_ids_path}
+              - format: numpy
+                in_memory: true
+                path: {validation_labels_path}
+        test_set:
+          - type: null
+            data:
+              - format: numpy
+                in_memory: true
+                path: {test_ids_path}
+              - format: numpy
+                in_memory: true
+                path: {test_labels_path}
         """
         yaml_file = os.path.join(test_dir, "test.yaml")
         with open(yaml_file, "w") as f:
@@ -116,55 +116,49 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
         dataset = gb.OnDiskDataset(yaml_file)
 
         # Verify train set.
-        train_sets = dataset.train_sets
-        assert len(train_sets) == 1
-        for train_set in train_sets:
-            assert len(train_set) == 1000
-            assert isinstance(train_set, gb.ItemSet)
-            for i, (id, label) in enumerate(train_set):
-                assert id == train_ids[i]
-                assert label == train_labels[i]
-        train_sets = None
+        train_set = dataset.train_set
+        assert len(train_set) == 1000
+        assert isinstance(train_set, gb.ItemSet)
+        for i, (id, label) in enumerate(train_set):
+            assert id == train_ids[i]
+            assert label == train_labels[i]
+        train_set = None
 
         # Verify validation set.
-        validation_sets = dataset.validation_sets
-        assert len(validation_sets) == 1
-        for validation_set in validation_sets:
-            assert len(validation_set) == 1000
-            assert isinstance(validation_set, gb.ItemSet)
-            for i, (id, label) in enumerate(validation_set):
-                assert id == validation_ids[i]
-                assert label == validation_labels[i]
-        validation_sets = None
+        validation_set = dataset.validation_set
+        assert len(validation_set) == 1000
+        assert isinstance(validation_set, gb.ItemSet)
+        for i, (id, label) in enumerate(validation_set):
+            assert id == validation_ids[i]
+            assert label == validation_labels[i]
+        validation_set = None
 
         # Verify test set.
-        test_sets = dataset.test_sets
-        assert len(test_sets) == 1
-        for test_set in test_sets:
-            assert len(test_set) == 1000
-            assert isinstance(test_set, gb.ItemSet)
-            for i, (id, label) in enumerate(test_set):
-                assert id == test_ids[i]
-                assert label == test_labels[i]
-        test_sets = None
+        test_set = dataset.test_set
+        assert len(test_set) == 1000
+        assert isinstance(test_set, gb.ItemSet)
+        for i, (id, label) in enumerate(test_set):
+            assert id == test_ids[i]
+            assert label == test_labels[i]
+        test_set = None
         dataset = None
 
         # Case 2: Some TVT sets are None.
         yaml_content = f"""
-        train_sets:
-          - - type: null
-              data:
-                - format: numpy
-                  path: {train_ids_path}
+        train_set:
+          - type: null
+            data:
+              - format: numpy
+                path: {train_ids_path}
         """
         yaml_file = os.path.join(test_dir, "test.yaml")
         with open(yaml_file, "w") as f:
            f.write(yaml_content)
         dataset = gb.OnDiskDataset(yaml_file)
-        assert dataset.train_sets is not None
-        assert dataset.validation_sets is None
-        assert dataset.test_sets is None
+        assert dataset.train_set is not None
+        assert dataset.validation_set is None
+        assert dataset.test_set is None
         dataset = None
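As Case 2 above verifies, splits omitted from the YAML come back as `None` rather than an empty list, so callers can guard with a plain identity check (the file name here is illustrative):

from dgl import graphbolt as gb

dataset = gb.OnDiskDataset("train_only.yaml")  # illustrative file
if dataset.validation_set is None:
    print("No validation split provided.")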
@@ -202,41 +196,41 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
         np.save(test_labels_path, test_labels)
 
         yaml_content = f"""
-        train_sets:
-          - - type: null
-              data:
-                - format: numpy
-                  in_memory: true
-                  path: {train_src_path}
-                - format: numpy
-                  in_memory: true
-                  path: {train_dst_path}
-                - format: numpy
-                  in_memory: true
-                  path: {train_labels_path}
-        validation_sets:
-          - - data:
-                - format: numpy
-                  in_memory: true
-                  path: {validation_src_path}
-                - format: numpy
-                  in_memory: true
-                  path: {validation_dst_path}
-                - format: numpy
-                  in_memory: true
-                  path: {validation_labels_path}
-        test_sets:
-          - - type: null
-              data:
-                - format: numpy
-                  in_memory: true
-                  path: {test_src_path}
-                - format: numpy
-                  in_memory: true
-                  path: {test_dst_path}
-                - format: numpy
-                  in_memory: true
-                  path: {test_labels_path}
+        train_set:
+          - type: null
+            data:
+              - format: numpy
+                in_memory: true
+                path: {train_src_path}
+              - format: numpy
+                in_memory: true
+                path: {train_dst_path}
+              - format: numpy
+                in_memory: true
+                path: {train_labels_path}
+        validation_set:
+          - data:
+              - format: numpy
+                in_memory: true
+                path: {validation_src_path}
+              - format: numpy
+                in_memory: true
+                path: {validation_dst_path}
+              - format: numpy
+                in_memory: true
+                path: {validation_labels_path}
+        test_set:
+          - type: null
+            data:
+              - format: numpy
+                in_memory: true
+                path: {test_src_path}
+              - format: numpy
+                in_memory: true
+                path: {test_dst_path}
+              - format: numpy
+                in_memory: true
+                path: {test_labels_path}
         """
         yaml_file = os.path.join(test_dir, "test.yaml")
         with open(yaml_file, "w") as f:
@@ -245,40 +239,34 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
         dataset = gb.OnDiskDataset(yaml_file)
 
         # Verify train set.
-        train_sets = dataset.train_sets
-        assert len(train_sets) == 1
-        for train_set in train_sets:
-            assert len(train_set) == 1000
-            assert isinstance(train_set, gb.ItemSet)
-            for i, (src, dst, label) in enumerate(train_set):
-                assert src == train_src[i]
-                assert dst == train_dst[i]
-                assert label == train_labels[i]
-        train_sets = None
+        train_set = dataset.train_set
+        assert len(train_set) == 1000
+        assert isinstance(train_set, gb.ItemSet)
+        for i, (src, dst, label) in enumerate(train_set):
+            assert src == train_src[i]
+            assert dst == train_dst[i]
+            assert label == train_labels[i]
+        train_set = None
 
         # Verify validation set.
-        validation_sets = dataset.validation_sets
-        assert len(validation_sets) == 1
-        for validation_set in validation_sets:
-            assert len(validation_set) == 1000
-            assert isinstance(validation_set, gb.ItemSet)
-            for i, (src, dst, label) in enumerate(validation_set):
-                assert src == validation_src[i]
-                assert dst == validation_dst[i]
-                assert label == validation_labels[i]
-        validation_sets = None
+        validation_set = dataset.validation_set
+        assert len(validation_set) == 1000
+        assert isinstance(validation_set, gb.ItemSet)
+        for i, (src, dst, label) in enumerate(validation_set):
+            assert src == validation_src[i]
+            assert dst == validation_dst[i]
+            assert label == validation_labels[i]
+        validation_set = None
 
         # Verify test set.
-        test_sets = dataset.test_sets
-        assert len(test_sets) == 1
-        for test_set in test_sets:
-            assert len(test_set) == 1000
-            assert isinstance(test_set, gb.ItemSet)
-            for i, (src, dst, label) in enumerate(test_set):
-                assert src == test_src[i]
-                assert dst == test_dst[i]
-                assert label == test_labels[i]
-        test_sets = None
+        test_set = dataset.test_set
+        assert len(test_set) == 1000
+        assert isinstance(test_set, gb.ItemSet)
+        for i, (src, dst, label) in enumerate(test_set):
+            assert src == test_src[i]
+            assert dst == test_dst[i]
+            assert label == test_labels[i]
+        test_set = None
         dataset = None
@@ -320,41 +308,41 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
         np.save(test_neg_dst_path, test_neg_dst)
 
         yaml_content = f"""
-        train_sets:
-          - - type: null
-              data:
-                - format: numpy
-                  in_memory: true
-                  path: {train_src_path}
-                - format: numpy
-                  in_memory: true
-                  path: {train_dst_path}
-                - format: numpy
-                  in_memory: true
-                  path: {train_neg_dst_path}
-        validation_sets:
-          - - data:
-                - format: numpy
-                  in_memory: true
-                  path: {validation_src_path}
-                - format: numpy
-                  in_memory: true
-                  path: {validation_dst_path}
-                - format: numpy
-                  in_memory: true
-                  path: {validation_neg_dst_path}
-        test_sets:
-          - - type: null
-              data:
-                - format: numpy
-                  in_memory: true
-                  path: {test_src_path}
-                - format: numpy
-                  in_memory: true
-                  path: {test_dst_path}
-                - format: numpy
-                  in_memory: true
-                  path: {test_neg_dst_path}
+        train_set:
+          - type: null
+            data:
+              - format: numpy
+                in_memory: true
+                path: {train_src_path}
+              - format: numpy
+                in_memory: true
+                path: {train_dst_path}
+              - format: numpy
+                in_memory: true
+                path: {train_neg_dst_path}
+        validation_set:
+          - data:
+              - format: numpy
+                in_memory: true
+                path: {validation_src_path}
+              - format: numpy
+                in_memory: true
+                path: {validation_dst_path}
+              - format: numpy
+                in_memory: true
+                path: {validation_neg_dst_path}
+        test_set:
+          - type: null
+            data:
+              - format: numpy
+                in_memory: true
+                path: {test_src_path}
+              - format: numpy
+                in_memory: true
+                path: {test_dst_path}
+              - format: numpy
+                in_memory: true
+                path: {test_neg_dst_path}
         """
         yaml_file = os.path.join(test_dir, "test.yaml")
         with open(yaml_file, "w") as f:
@@ -363,42 +351,34 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
         dataset = gb.OnDiskDataset(yaml_file)
 
         # Verify train set.
-        train_sets = dataset.train_sets
-        assert len(train_sets) == 1
-        for train_set in train_sets:
-            assert len(train_set) == 1000
-            assert isinstance(train_set, gb.ItemSet)
-            for i, (src, dst, negs) in enumerate(train_set):
-                assert src == train_src[i]
-                assert dst == train_dst[i]
-                assert torch.equal(negs, torch.from_numpy(train_neg_dst[i]))
-        train_sets = None
+        train_set = dataset.train_set
+        assert len(train_set) == 1000
+        assert isinstance(train_set, gb.ItemSet)
+        for i, (src, dst, negs) in enumerate(train_set):
+            assert src == train_src[i]
+            assert dst == train_dst[i]
+            assert torch.equal(negs, torch.from_numpy(train_neg_dst[i]))
+        train_set = None
 
         # Verify validation set.
-        validation_sets = dataset.validation_sets
-        assert len(validation_sets) == 1
-        for validation_set in validation_sets:
-            assert len(validation_set) == 1000
-            assert isinstance(validation_set, gb.ItemSet)
-            for i, (src, dst, negs) in enumerate(validation_set):
-                assert src == validation_src[i]
-                assert dst == validation_dst[i]
-                assert torch.equal(
-                    negs, torch.from_numpy(validation_neg_dst[i])
-                )
-        validation_sets = None
+        validation_set = dataset.validation_set
+        assert len(validation_set) == 1000
+        assert isinstance(validation_set, gb.ItemSet)
+        for i, (src, dst, negs) in enumerate(validation_set):
+            assert src == validation_src[i]
+            assert dst == validation_dst[i]
+            assert torch.equal(negs, torch.from_numpy(validation_neg_dst[i]))
+        validation_set = None
 
         # Verify test set.
-        test_sets = dataset.test_sets
-        assert len(test_sets) == 1
-        for test_set in test_sets:
-            assert len(test_set) == 1000
-            assert isinstance(test_set, gb.ItemSet)
-            for i, (src, dst, negs) in enumerate(test_set):
-                assert src == test_src[i]
-                assert dst == test_dst[i]
-                assert torch.equal(negs, torch.from_numpy(test_neg_dst[i]))
-        test_sets = None
+        test_set = dataset.test_set
+        assert len(test_set) == 1000
+        assert isinstance(test_set, gb.ItemSet)
+        for i, (src, dst, negs) in enumerate(test_set):
+            assert src == test_src[i]
+            assert dst == test_dst[i]
+            assert torch.equal(negs, torch.from_numpy(test_neg_dst[i]))
+        test_set = None
         dataset = None
@@ -424,35 +404,35 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
         np.save(test_path, test_data)
 
         yaml_content = f"""
-        train_sets:
-          - - type: paper
-              data:
-                - format: numpy
-                  in_memory: true
-                  path: {train_path}
-          - - type: author
-              data:
-                - format: numpy
-                  path: {train_path}
-        validation_sets:
-          - - type: paper
-              data:
-                - format: numpy
-                  path: {validation_path}
-          - - type: author
-              data:
-                - format: numpy
-                  path: {validation_path}
-        test_sets:
-          - - type: paper
-              data:
-                - format: numpy
-                  in_memory: false
-                  path: {test_path}
-          - - type: author
-              data:
-                - format: numpy
-                  path: {test_path}
+        train_set:
+          - type: paper
+            data:
+              - format: numpy
+                in_memory: true
+                path: {train_path}
+          - type: author
+            data:
+              - format: numpy
+                path: {train_path}
+        validation_set:
+          - type: paper
+            data:
+              - format: numpy
+                path: {validation_path}
+          - type: author
+            data:
+              - format: numpy
+                path: {validation_path}
+        test_set:
+          - type: paper
+            data:
+              - format: numpy
+                in_memory: false
+                path: {test_path}
+          - type: author
+            data:
+              - format: numpy
+                path: {test_path}
         """
         yaml_file = os.path.join(test_dir, "test.yaml")
         with open(yaml_file, "w") as f:
@@ -461,52 +441,46 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
         dataset = gb.OnDiskDataset(yaml_file)
 
         # Verify train set.
-        train_sets = dataset.train_sets
-        assert len(train_sets) == 2
-        for train_set in train_sets:
-            assert len(train_set) == 1000
-            assert isinstance(train_set, gb.ItemSetDict)
-            for i, item in enumerate(train_set):
-                assert isinstance(item, dict)
-                assert len(item) == 1
-                key = list(item.keys())[0]
-                assert key in ["paper", "author"]
-                id, label = item[key]
-                assert id == train_ids[i]
-                assert label == train_labels[i]
-        train_sets = None
+        train_set = dataset.train_set
+        assert len(train_set) == 2000
+        assert isinstance(train_set, gb.ItemSetDict)
+        for i, item in enumerate(train_set):
+            assert isinstance(item, dict)
+            assert len(item) == 1
+            key = list(item.keys())[0]
+            assert key in ["paper", "author"]
+            id, label = item[key]
+            assert id == train_ids[i % 1000]
+            assert label == train_labels[i % 1000]
+        train_set = None
 
         # Verify validation set.
-        validation_sets = dataset.validation_sets
-        assert len(validation_sets) == 2
-        for validation_set in validation_sets:
-            assert len(validation_set) == 1000
-            assert isinstance(train_set, gb.ItemSetDict)
-            for i, item in enumerate(validation_set):
-                assert isinstance(item, dict)
-                assert len(item) == 1
-                key = list(item.keys())[0]
-                assert key in ["paper", "author"]
-                id, label = item[key]
-                assert id == validation_ids[i]
-                assert label == validation_labels[i]
-        validation_sets = None
+        validation_set = dataset.validation_set
+        assert len(validation_set) == 2000
+        assert isinstance(validation_set, gb.ItemSetDict)
+        for i, item in enumerate(validation_set):
+            assert isinstance(item, dict)
+            assert len(item) == 1
+            key = list(item.keys())[0]
+            assert key in ["paper", "author"]
+            id, label = item[key]
+            assert id == validation_ids[i % 1000]
+            assert label == validation_labels[i % 1000]
+        validation_set = None
 
         # Verify test set.
-        test_sets = dataset.test_sets
-        assert len(test_sets) == 2
-        for test_set in test_sets:
-            assert len(test_set) == 1000
-            assert isinstance(train_set, gb.ItemSetDict)
-            for i, item in enumerate(test_set):
-                assert isinstance(item, dict)
-                assert len(item) == 1
-                key = list(item.keys())[0]
-                assert key in ["paper", "author"]
-                id, label = item[key]
-                assert id == test_ids[i]
-                assert label == test_labels[i]
-        test_sets = None
+        test_set = dataset.test_set
+        assert len(test_set) == 2000
+        assert isinstance(test_set, gb.ItemSetDict)
+        for i, item in enumerate(test_set):
+            assert isinstance(item, dict)
+            assert len(item) == 1
+            key = list(item.keys())[0]
+            assert key in ["paper", "author"]
+            id, label = item[key]
+            assert id == test_ids[i % 1000]
+            assert label == test_labels[i % 1000]
+        test_set = None
         dataset = None
@@ -532,35 +506,35 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
         np.save(test_path, test_data)
 
         yaml_content = f"""
-        train_sets:
-          - - type: paper
-              data:
-                - format: numpy
-                  in_memory: true
-                  path: {train_path}
-          - - type: author
-              data:
-                - format: numpy
-                  path: {train_path}
-        validation_sets:
-          - - type: paper
-              data:
-                - format: numpy
-                  path: {validation_path}
-          - - type: author
-              data:
-                - format: numpy
-                  path: {validation_path}
-        test_sets:
-          - - type: paper
-              data:
-                - format: numpy
-                  in_memory: false
-                  path: {test_path}
-          - - type: author
-              data:
-                - format: numpy
-                  path: {test_path}
+        train_set:
+          - type: paper
+            data:
+              - format: numpy
+                in_memory: true
+                path: {train_path}
+          - type: author
+            data:
+              - format: numpy
+                path: {train_path}
+        validation_set:
+          - type: paper
+            data:
+              - format: numpy
+                path: {validation_path}
+          - type: author
+            data:
+              - format: numpy
+                path: {validation_path}
+        test_set:
+          - type: paper
+            data:
+              - format: numpy
+                in_memory: false
+                path: {test_path}
+          - type: author
+            data:
+              - format: numpy
+                path: {test_path}
         """
         yaml_file = os.path.join(test_dir, "test.yaml")
         with open(yaml_file, "w") as f:
@@ -569,55 +543,49 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
         dataset = gb.OnDiskDataset(yaml_file)
 
         # Verify train set.
-        train_sets = dataset.train_sets
-        assert len(train_sets) == 2
-        for train_set in train_sets:
-            assert len(train_set) == 1000
-            assert isinstance(train_set, gb.ItemSetDict)
-            for i, item in enumerate(train_set):
-                assert isinstance(item, dict)
-                assert len(item) == 1
-                key = list(item.keys())[0]
-                assert key in ["paper", "author"]
-                src, dst, label = item[key]
-                assert src == train_pairs[0][i]
-                assert dst == train_pairs[1][i]
-                assert label == train_labels[i]
-        train_sets = None
+        train_set = dataset.train_set
+        assert len(train_set) == 2000
+        assert isinstance(train_set, gb.ItemSetDict)
+        for i, item in enumerate(train_set):
+            assert isinstance(item, dict)
+            assert len(item) == 1
+            key = list(item.keys())[0]
+            assert key in ["paper", "author"]
+            src, dst, label = item[key]
+            assert src == train_pairs[0][i % 1000]
+            assert dst == train_pairs[1][i % 1000]
+            assert label == train_labels[i % 1000]
+        train_set = None
 
         # Verify validation set.
-        validation_sets = dataset.validation_sets
-        assert len(validation_sets) == 2
-        for validation_set in validation_sets:
-            assert len(validation_set) == 1000
-            assert isinstance(train_set, gb.ItemSetDict)
-            for i, item in enumerate(validation_set):
-                assert isinstance(item, dict)
-                assert len(item) == 1
-                key = list(item.keys())[0]
-                assert key in ["paper", "author"]
-                src, dst, label = item[key]
-                assert src == validation_pairs[0][i]
-                assert dst == validation_pairs[1][i]
-                assert label == validation_labels[i]
-        validation_sets = None
+        validation_set = dataset.validation_set
+        assert len(validation_set) == 2000
+        assert isinstance(validation_set, gb.ItemSetDict)
+        for i, item in enumerate(validation_set):
+            assert isinstance(item, dict)
+            assert len(item) == 1
+            key = list(item.keys())[0]
+            assert key in ["paper", "author"]
+            src, dst, label = item[key]
+            assert src == validation_pairs[0][i % 1000]
+            assert dst == validation_pairs[1][i % 1000]
+            assert label == validation_labels[i % 1000]
+        validation_set = None
 
         # Verify test set.
-        test_sets = dataset.test_sets
-        assert len(test_sets) == 2
-        for test_set in test_sets:
-            assert len(test_set) == 1000
-            assert isinstance(train_set, gb.ItemSetDict)
-            for i, item in enumerate(test_set):
-                assert isinstance(item, dict)
-                assert len(item) == 1
-                key = list(item.keys())[0]
-                assert key in ["paper", "author"]
-                src, dst, label = item[key]
-                assert src == test_pairs[0][i]
-                assert dst == test_pairs[1][i]
-                assert label == test_labels[i]
-        test_sets = None
+        test_set = dataset.test_set
+        assert len(test_set) == 2000
+        assert isinstance(test_set, gb.ItemSetDict)
+        for i, item in enumerate(test_set):
+            assert isinstance(item, dict)
+            assert len(item) == 1
+            key = list(item.keys())[0]
+            assert key in ["paper", "author"]
+            src, dst, label = item[key]
+            assert src == test_pairs[0][i % 1000]
+            assert dst == test_pairs[1][i % 1000]
+            assert label == test_labels[i % 1000]
+        test_set = None
         dataset = None
@@ -995,21 +963,21 @@ def test_OnDiskDataset_preprocess_homogeneous():
             format: numpy
             in_memory: false
             path: data/node-feat.npy
-        train_sets:
-          - - type_name: null
-              data:
-                - format: numpy
-                  path: set/train.npy
-        validation_sets:
-          - - type_name: null
-              data:
-                - format: numpy
-                  path: set/validation.npy
-        test_sets:
-          - - type_name: null
-              data:
-                - format: numpy
-                  path: set/test.npy
+        train_set:
+          - type_name: null
+            data:
+              - format: numpy
+                path: set/train.npy
+        validation_set:
+          - type_name: null
+            data:
+              - format: numpy
+                path: set/validation.npy
+        test_set:
+          - type_name: null
+            data:
+              - format: numpy
+                path: set/test.npy
         """
         yaml_file = os.path.join(test_dir, "test.yaml")
         with open(yaml_file, "w") as f:
...