Unverified commit 45ac5726, authored by Rhett Ying, committed by GitHub

[data] refine AsNodePredDataset and add tests for DGLCSVDataset (#3722)

* [data] refine AsNodePredDataset and add tests for DGLCSVDataset

* fix

* remove add_self_loop

* refine
parent fcd8ed9a
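
What changed, in short: AsNodePredDataset used to read dataset.num_classes unconditionally, so wrapping a dataset that does not define that attribute (DGLCSVDataset, for one) raised an AttributeError. It now stores None in that case and infers the class count from the unique label values during process(). A usage sketch of the new behavior, assuming a hypothetical /path/to/csv_dataset directory holding a valid meta.yaml and its CSV files; the split_ratio values are illustrative, not taken from this commit:

    import dgl.data as data

    # DGLCSVDataset exposes no num_classes of its own.
    ds = data.DGLCSVDataset('/path/to/csv_dataset')

    # The adapter now infers num_classes from the unique 'label' values
    # instead of raising, and generates the train/val/test masks.
    new_ds = data.AsNodePredDataset(ds, split_ratio=[0.8, 0.1, 0.1])
    print(new_ds.num_classes)
    print('train_mask' in new_ds[0].ndata)  # True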
@@ -5,6 +5,7 @@ import json
 from .dgl_dataset import DGLDataset
 from . import utils
+from .. import backend as F

 __all__ = ['AsNodePredDataset']
@@ -68,13 +69,15 @@ class AsNodePredDataset(DGLDataset):
         self.g = dataset[0].clone()
         self.split_ratio = split_ratio
         self.target_ntype = target_ntype
-        self.num_classes = dataset.num_classes
+        self.num_classes = getattr(dataset, 'num_classes', None)
         super().__init__(dataset.name + '-as-nodepred', **kwargs)

     def process(self):
         if 'label' not in self.g.nodes[self.target_ntype].data:
             raise ValueError("Missing node labels. Make sure labels are stored "
                              "under name 'label'.")
+        if self.num_classes is None:
+            self.num_classes = len(F.unique(self.g.nodes[self.target_ntype].data['label']))
         if self.verbose:
             print('Generating train/val/test masks...')
         utils.add_nodepred_split(self, self.split_ratio, self.target_ntype)
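
The fallback is just a unique-count over the label tensor. A standalone sketch of the same logic, with plain PyTorch standing in for DGL's backend shim (torch.unique plays the role of F.unique here; the adapter itself stays framework-agnostic):

    import torch

    # Stand-in for self.g.nodes[self.target_ntype].data['label'].
    labels = torch.tensor([0, 2, 1, 2, 0, 1, 1])

    num_classes = None  # what getattr(dataset, 'num_classes', None) returned
    if num_classes is None:
        # One class per distinct label value, as in the new process() branch.
        num_classes = len(torch.unique(labels))

    assert num_classes == 3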
@@ -1067,6 +1067,48 @@ def test_as_nodepred2():
     ds = data.AsNodePredDataset(data.AIFBDataset(), [0.1, 0.1, 0.8], 'Personen', verbose=True)
     assert F.sum(F.astype(ds[0].nodes['Personen'].data['train_mask'], F.int32), 0) == int(ds[0].num_nodes('Personen') * 0.1)
+
+@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
+def test_as_nodepred_csvdataset():
+    with tempfile.TemporaryDirectory() as test_dir:
+        # generate YAML/CSVs
+        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
+        edges_csv_path = os.path.join(test_dir, "test_edges.csv")
+        nodes_csv_path = os.path.join(test_dir, "test_nodes.csv")
+        meta_yaml_data = {'version': '1.0.0', 'dataset_name': 'default_name',
+                          'node_data': [{'file_name': os.path.basename(nodes_csv_path)
+                                         }],
+                          'edge_data': [{'file_name': os.path.basename(edges_csv_path)
+                                         }],
+                          }
+        with open(meta_yaml_path, 'w') as f:
+            yaml.dump(meta_yaml_data, f, sort_keys=False)
+        num_nodes = 100
+        num_edges = 500
+        num_dims = 3
+        num_classes = num_nodes
+        feat_ndata = np.random.rand(num_nodes, num_dims)
+        label_ndata = np.arange(num_classes)
+        df = pd.DataFrame({'node_id': np.arange(num_nodes),
+                           'label': label_ndata,
+                           'feat': [line.tolist() for line in feat_ndata],
+                           })
+        df.to_csv(nodes_csv_path, index=False)
+        df = pd.DataFrame({'src_id': np.random.randint(num_nodes, size=num_edges),
+                           'dst_id': np.random.randint(num_nodes, size=num_edges),
+                           })
+        df.to_csv(edges_csv_path, index=False)
+        ds = data.DGLCSVDataset(test_dir, force_reload=True)
+        assert 'feat' in ds[0].ndata
+        assert 'label' in ds[0].ndata
+        assert 'train_mask' not in ds[0].ndata
+        assert not hasattr(ds, 'num_classes')
+        new_ds = data.AsNodePredDataset(ds, force_reload=True)
+        assert new_ds.num_classes == num_classes
+        assert 'feat' in new_ds[0].ndata
+        assert 'label' in new_ds[0].ndata
+        assert 'train_mask' in new_ds[0].ndata
+
 if __name__ == '__main__':
     test_minigc()
     test_gin()
@@ -1079,3 +1121,4 @@ if __name__ == '__main__':
     test_add_nodepred_split()
     test_as_nodepred1()
     test_as_nodepred2()
+    test_as_nodepred_csvdataset()
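
For reference, the yaml.dump call in the new test materializes a meta.yaml of this shape (values exactly as the test constructs them), next to test_nodes.csv (columns node_id, label, feat) and test_edges.csv (columns src_id, dst_id):

    version: 1.0.0
    dataset_name: default_name
    node_data:
    - file_name: test_nodes.csv
    edge_data:
    - file_name: test_edges.csv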