"docs/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "fe4d17fc43e9e0674d741712a68fceba620c2c8f"
Commit 9e532e7d authored by Mufei Li, committed by GitHub

[Dataset] Chameleon (#5477)



* update

* update

* update

* lint

* update

* CI

* lint

* update doc

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
parent 170203ae
...@@ -56,6 +56,7 @@ Datasets for node classification/regression tasks
    YelpDataset
    PATTERNDataset
    CLUSTERDataset
+   ChameleonDataset

 Edge Prediction Datasets
 ---------------------------------------
......
...@@ -54,6 +54,7 @@ from .tu import LegacyTUDataset, TUDataset
 from .utils import *
 from .cluster import CLUSTERDataset
 from .pattern import PATTERNDataset
+from .wiki_network import ChameleonDataset
 from .wikics import WikiCSDataset
 from .yelp import YelpDataset
 from .zinc import ZINCDataset
......
...@@ -6,7 +6,6 @@ from __future__ import absolute_import
 import abc
 import hashlib
 import os
-import sys
 import traceback

 from ..utils import retry_method_with_fix
...@@ -221,6 +220,15 @@ class DGLDataset(object):
         hash_func.update(str(self._hash_key).encode("utf-8"))
         return hash_func.hexdigest()[:8]

+    def _get_hash_url_suffix(self):
+        """Get the suffix based on the hash value of the url."""
+        if self._url is None:
+            return ""
+        else:
+            hash_func = hashlib.sha1()
+            hash_func.update(str(self._url).encode("utf-8"))
+            return "_" + hash_func.hexdigest()[:8]
+
     @property
     def url(self):
         r"""Get url to download the raw dataset."""
...@@ -241,7 +249,9 @@ class DGLDataset(object):
         r"""Directory contains the input data files.
         By default raw_path = os.path.join(self.raw_dir, self.name)
         """
-        return os.path.join(self.raw_dir, self.name)
+        return os.path.join(
+            self.raw_dir, self.name + self._get_hash_url_suffix()
+        )

     @property
     def save_dir(self):
...@@ -251,7 +261,9 @@ class DGLDataset(object):
     @property
     def save_path(self):
         r"""Path to save the processed dataset."""
-        return os.path.join(self._save_dir, self.name)
+        return os.path.join(
+            self.save_dir, self.name + self._get_hash_url_suffix()
+        )

     @property
     def verbose(self):
......
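The new helper exists so that two datasets sharing a name but fetched from different URLs no longer collide in the same cache directory. A minimal standalone sketch of the resulting paths; the second URL is a hypothetical mirror, not a real DGL endpoint:

import hashlib
import os

def hash_url_suffix(url):
    # Mirrors DGLDataset._get_hash_url_suffix: empty when there is no URL,
    # otherwise "_" + the first 8 hex chars of the URL's SHA-1 digest.
    if url is None:
        return ""
    return "_" + hashlib.sha1(str(url).encode("utf-8")).hexdigest()[:8]

raw_dir = os.path.expanduser("~/.dgl")
for url in [
    "https://data.dgl.ai/dataset/chameleon.zip",
    "https://example.com/mirror/chameleon.zip",  # hypothetical mirror
]:
    # Same dataset name, different URLs -> distinct cache directories.
    print(os.path.join(raw_dir, "chameleon" + hash_url_suffix(url)))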
"""QM7b dataset for graph property prediction (regression).""" """QM7b dataset for graph property prediction (regression)."""
import os import os
import numpy as np
from scipy import io from scipy import io
from .. import backend as F from .. import backend as F
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from .dgl_dataset import DGLDataset from .dgl_dataset import DGLDataset
from .utils import ( from .utils import check_sha1, download, load_graphs, save_graphs
check_sha1,
deprecate_property,
download,
load_graphs,
save_graphs,
)
class QM7bDataset(DGLDataset): class QM7bDataset(DGLDataset):
...@@ -93,7 +86,7 @@ class QM7bDataset(DGLDataset): ...@@ -93,7 +86,7 @@ class QM7bDataset(DGLDataset):
) )
def process(self): def process(self):
mat_path = self.raw_path + ".mat" mat_path = os.path.join(self.raw_dir, self.name + ".mat")
self.graphs, self.label = self._load_graph(mat_path) self.graphs, self.label = self._load_graph(mat_path)
def _load_graph(self, filename): def _load_graph(self, filename):
......
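The QM7b tweak follows directly from the suffix above: raw_path may now carry a URL-hash suffix, while the downloaded archive on disk is still named after the bare dataset name, so the .mat file has to be located from raw_dir and name directly. A small sketch of the two constructions; the directory and suffix values are illustrative:

import os

raw_dir = os.path.expanduser("~/.dgl")  # default raw file directory
name = "qm7b"
url_suffix = "_1a2b3c4d"                # illustrative URL-hash suffix

# Old: built on raw_path, which now includes the suffix and therefore
# points at a file the downloader never wrote.
old_mat_path = os.path.join(raw_dir, name + url_suffix) + ".mat"

# New: built from raw_dir and the bare name, matching the downloaded file.
new_mat_path = os.path.join(raw_dir, name + ".mat")
print(old_mat_path, new_mat_path, sep="\n")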
"""
Wikipedia page-page networks on the chameleon topic.
"""
import os
import numpy as np
from ..convert import graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url
class WikiNetworkDataset(DGLBuiltinDataset):
r"""Wikipedia page-page networks from `Multi-scale Attributed
Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later modified by
`Geom-GCN: Geometric Graph Convolutional Networks
<https://arxiv.org/abs/2002.05287>`
Parameters
----------
name : str
Name of the dataset.
raw_dir : str
Raw file directory to store the processed data.
force_reload : bool
Whether to always generate the data from scratch rather than load a
cached version.
verbose : bool
Whether to print progress information.
transform : callable
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(self, name, raw_dir, force_reload, verbose, transform):
url = _get_dgl_url(f"dataset/{name}.zip")
super(WikiNetworkDataset, self).__init__(
name=name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self):
"""Load and process the data."""
try:
import torch
except ImportError:
raise ModuleNotFoundError(
"This dataset requires PyTorch to be the backend."
)
# Process node features and labels.
with open(f"{self.raw_path}/out1_node_feature_label.txt", "r") as f:
data = f.read().split("\n")[1:-1]
features = [
[float(v) for v in r.split("\t")[1].split(",")] for r in data
]
features = torch.tensor(features, dtype=torch.float)
labels = [int(r.split("\t")[2]) for r in data]
self._num_classes = max(labels) + 1
labels = torch.tensor(labels, dtype=torch.long)
# Process graph structure.
with open(f"{self.raw_path}/out1_graph_edges.txt", "r") as f:
data = f.read().split("\n")[1:-1]
data = [[int(v) for v in r.split("\t")] for r in data]
dst, src = torch.tensor(data, dtype=torch.long).t().contiguous()
self._g = graph((src, dst), num_nodes=features.size(0))
self._g.ndata["feat"] = features
self._g.ndata["label"] = labels
# Process 10 train/val/test node splits.
train_masks, val_masks, test_masks = [], [], []
for i in range(10):
filepath = f"{self.raw_path}/{self.name}_split_0.6_0.2_{i}.npz"
f = np.load(filepath)
train_masks += [torch.from_numpy(f["train_mask"])]
val_masks += [torch.from_numpy(f["val_mask"])]
test_masks += [torch.from_numpy(f["test_mask"])]
self._g.ndata["train_mask"] = torch.stack(train_masks, dim=1).bool()
self._g.ndata["val_mask"] = torch.stack(val_masks, dim=1).bool()
self._g.ndata["test_mask"] = torch.stack(test_masks, dim=1).bool()
def has_cache(self):
return os.path.exists(self.raw_path)
def load(self):
self.process()
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph."
if self._transform is None:
return self._g
else:
return self._transform(self._g)
def __len__(self):
return 1
@property
def num_classes(self):
return self._num_classes
class ChameleonDataset(WikiNetworkDataset):
r"""Wikipedia page-page network on chameleons from `Multi-scale Attributed
Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later modified by
`Geom-GCN: Geometric Graph Convolutional Networks
<https://arxiv.org/abs/2002.05287>`
Nodes represent articles from the English Wikipedia, edges reflect mutual
links between them. Node features indicate the presence of particular nouns
in the articles. The nodes were classified into 5 classes in terms of their
average monthly traffic.
Statistics:
- Nodes: 2277
- Edges: 36101
- Number of Classes: 5
- 10 splits with 60/20/20 train/val/test ratio
- Train: 1092
- Val: 729
- Test: 456
Parameters
----------
raw_dir : str, optional
Raw file directory to store the processed data. Default: ~/.dgl/
force_reload : bool, optional
Whether to always generate the data from scratch rather than load a
cached version. Default: False
verbose : bool, optional
Whether to print progress information. Default: True
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access. Default: None
Attributes
----------
num_classes : int
Number of node classes
Notes
-----
The graph does not come with edges for both directions.
Examples
--------
>>> from dgl.data import ChameleonDataset
>>> dataset = ChameleonDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get data split
>>> train_mask = g.ndata["train_mask"]
>>> val_mask = g.ndata["val_mask"]
>>> test_mask = g.ndata["test_mask"]
>>> # get labels
>>> label = g.ndata['label']
"""
def __init__(
self, raw_dir=None, force_reload=False, verbose=True, transform=None
):
super(ChameleonDataset, self).__init__(
name="chameleon",
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
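A short usage sketch for the class above, selecting one of the ten Geom-GCN splits; the split index and the downstream slicing are illustrative choices, not part of the dataset API:

from dgl.data import ChameleonDataset

dataset = ChameleonDataset()
g = dataset[0]

feat = g.ndata["feat"]    # (2277, feature_dim) node features
label = g.ndata["label"]  # (2277,) class ids in [0, 5)

# Masks are stacked over the 10 splits: shape (num_nodes, 10).
split = 0  # illustrative: use the first split
train_mask = g.ndata["train_mask"][:, split]
val_mask = g.ndata["val_mask"][:, split]
test_mask = g.ndata["test_mask"][:, split]

train_feat, train_label = feat[train_mask], label[train_mask]
print(train_feat.shape, int(train_mask.sum()))  # 1092 training nodes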
import unittest

import backend as F

import dgl


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
)
def test_chameleon():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    # chameleon
    g = dgl.data.ChameleonDataset(force_reload=True)[0]
    assert g.num_nodes() == 2277
    assert g.num_edges() == 36101
    g2 = dgl.data.ChameleonDataset(force_reload=True, transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()
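The final assertion holds because AddSelfLoop(allow_duplicate=True) appends one self-loop per node without deduplicating existing edges, so the transformed graph gains exactly g.num_nodes() edges.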