Unverified commit 20d49771, authored by Yu Sun and committed by GitHub
Browse files

[Feature] Add reverse edges in CitationGraphDataset (#2588)



* [Feature] Add reverse edges in CitationGraphDataset

* fix bug

* fix bug

* fix bug

* Update python/dgl/data/citation_graph.py
Co-authored-by: Xiangkun Hu <huxk_hit@qq.com>

* fix notes

* fix notes

* fix bugs

* solve requested changes

* fix bug

* fix bug
Co-authored-by: Xiangkun Hu <huxk_hit@qq.com>
parent 3a19bbc5
......@@ -43,7 +43,9 @@ class CitationGraphDataset(DGLBuiltinDataset):
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
"""
_urls = {
'cora_v2' : 'dataset/cora_v2.zip',
......@@ -51,7 +53,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
'pubmed' : 'dataset/pubmed.zip',
}
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
assert name.lower() in ['cora', 'citeseer', 'pubmed']
# Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
......@@ -60,6 +62,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
name = 'cora_v2'
url = _get_dgl_url(self._urls[name])
self._reverse_edge = reverse_edge
super(CitationGraphDataset, self).__init__(name,
url=url,
raw_dir=raw_dir,
......@@ -104,7 +108,11 @@ class CitationGraphDataset(DGLBuiltinDataset):
features = sp.vstack((allx, tx)).tolil()
features[test_idx_reorder, :] = features[test_idx_range, :]
graph = nx.DiGraph(nx.from_dict_of_lists(graph))
if self.reverse_edge:
graph = nx.DiGraph(nx.from_dict_of_lists(graph))
else:
graph = nx.Graph(nx.from_dict_of_lists(graph))
onehot_labels = np.vstack((ally, ty))
onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :]
......@@ -254,6 +262,11 @@ class CitationGraphDataset(DGLBuiltinDataset):
deprecate_property('dataset.feat', 'g.ndata[\'feat\']')
return self._g.ndata['feat']
@property
def reverse_edge(self):
return self._reverse_edge
def _preprocess_features(features):
"""Row-normalize feature matrix and convert to tuple representation"""
rowsum = np.asarray(features.sum(1))
......@@ -343,6 +356,8 @@ class CoraGraphDataset(CitationGraphDataset):
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Attributes
----------
......@@ -385,10 +400,10 @@ class CoraGraphDataset(CitationGraphDataset):
>>> # Train, Validation and Test
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
name = 'cora'
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose)
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -483,6 +498,8 @@ class CiteseerGraphDataset(CitationGraphDataset):
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Attributes
----------
......@@ -528,10 +545,10 @@ class CiteseerGraphDataset(CitationGraphDataset):
>>> # Train, Validation and Test
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
name = 'citeseer'
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose)
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -626,6 +643,8 @@ class PubmedGraphDataset(CitationGraphDataset):
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Attributes
----------
......@@ -668,10 +687,10 @@ class PubmedGraphDataset(CitationGraphDataset):
>>> # Train, Validation and Test
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
name = 'pubmed'
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose)
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -699,7 +718,7 @@ class PubmedGraphDataset(CitationGraphDataset):
r"""The number of graphs in the dataset."""
return super(PubmedGraphDataset, self).__len__()
def load_cora(raw_dir=None, force_reload=False, verbose=True):
def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
"""Get CoraGraphDataset
Parameters
......@@ -711,15 +730,17 @@ def load_cora(raw_dir=None, force_reload=False, verbose=True):
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Return
-------
CoraGraphDataset
"""
data = CoraGraphDataset(raw_dir, force_reload, verbose)
data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
return data
def load_citeseer(raw_dir=None, force_reload=False, verbose=True):
def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
"""Get CiteseerGraphDataset
Parameters
......@@ -731,15 +752,17 @@ def load_citeseer(raw_dir=None, force_reload=False, verbose=True):
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Return
-------
CiteseerGraphDataset
"""
data = CiteseerGraphDataset(raw_dir, force_reload, verbose)
data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
return data
def load_pubmed(raw_dir=None, force_reload=False, verbose=True):
def load_pubmed(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
"""Get PubmedGraphDataset
Parameters
......@@ -751,12 +774,14 @@ def load_pubmed(raw_dir=None, force_reload=False, verbose=True):
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Return
-------
PubmedGraphDataset
"""
data = PubmedGraphDataset(raw_dir, force_reload, verbose)
data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
return data
class CoraBinary(DGLBuiltinDataset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment