Unverified Commit 20d49771 authored by Yu Sun's avatar Yu Sun Committed by GitHub
Browse files

[Feature] Add reverse edges in CitationGraphDataset (#2588)



* [Feature] Add reverse edges in CitationGraphDataset

* fix bug

* fix bug

* fix bug

* Update python/dgl/data/citation_graph.py
Co-authored-by: default avatarXiangkun Hu <huxk_hit@qq.com>

* fix notes

* fix notes

* fix bugs

* solve requested changes

* fix bug

* fix bug
Co-authored-by: default avatarXiangkun Hu <huxk_hit@qq.com>
parent 3a19bbc5
...@@ -44,6 +44,8 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -44,6 +44,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose: bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
""" """
_urls = { _urls = {
'cora_v2' : 'dataset/cora_v2.zip', 'cora_v2' : 'dataset/cora_v2.zip',
...@@ -51,7 +53,7 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -51,7 +53,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
'pubmed' : 'dataset/pubmed.zip', 'pubmed' : 'dataset/pubmed.zip',
} }
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True): def __init__(self, name, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
assert name.lower() in ['cora', 'citeseer', 'pubmed'] assert name.lower() in ['cora', 'citeseer', 'pubmed']
# Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn) # Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
...@@ -60,6 +62,8 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -60,6 +62,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
name = 'cora_v2' name = 'cora_v2'
url = _get_dgl_url(self._urls[name]) url = _get_dgl_url(self._urls[name])
self._reverse_edge = reverse_edge
super(CitationGraphDataset, self).__init__(name, super(CitationGraphDataset, self).__init__(name,
url=url, url=url,
raw_dir=raw_dir, raw_dir=raw_dir,
...@@ -104,7 +108,11 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -104,7 +108,11 @@ class CitationGraphDataset(DGLBuiltinDataset):
features = sp.vstack((allx, tx)).tolil() features = sp.vstack((allx, tx)).tolil()
features[test_idx_reorder, :] = features[test_idx_range, :] features[test_idx_reorder, :] = features[test_idx_range, :]
if self.reverse_edge:
graph = nx.DiGraph(nx.from_dict_of_lists(graph)) graph = nx.DiGraph(nx.from_dict_of_lists(graph))
else:
graph = nx.Graph(nx.from_dict_of_lists(graph))
onehot_labels = np.vstack((ally, ty)) onehot_labels = np.vstack((ally, ty))
onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :] onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :]
...@@ -254,6 +262,11 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -254,6 +262,11 @@ class CitationGraphDataset(DGLBuiltinDataset):
deprecate_property('dataset.feat', 'g.ndata[\'feat\']') deprecate_property('dataset.feat', 'g.ndata[\'feat\']')
return self._g.ndata['feat'] return self._g.ndata['feat']
@property
def reverse_edge(self):
return self._reverse_edge
def _preprocess_features(features): def _preprocess_features(features):
"""Row-normalize feature matrix and convert to tuple representation""" """Row-normalize feature matrix and convert to tuple representation"""
rowsum = np.asarray(features.sum(1)) rowsum = np.asarray(features.sum(1))
...@@ -343,6 +356,8 @@ class CoraGraphDataset(CitationGraphDataset): ...@@ -343,6 +356,8 @@ class CoraGraphDataset(CitationGraphDataset):
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose: bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Attributes Attributes
---------- ----------
...@@ -385,10 +400,10 @@ class CoraGraphDataset(CitationGraphDataset): ...@@ -385,10 +400,10 @@ class CoraGraphDataset(CitationGraphDataset):
>>> # Train, Validation and Test >>> # Train, Validation and Test
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=True): def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
name = 'cora' name = 'cora'
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose) super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -483,6 +498,8 @@ class CiteseerGraphDataset(CitationGraphDataset): ...@@ -483,6 +498,8 @@ class CiteseerGraphDataset(CitationGraphDataset):
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose: bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Attributes Attributes
---------- ----------
...@@ -528,10 +545,10 @@ class CiteseerGraphDataset(CitationGraphDataset): ...@@ -528,10 +545,10 @@ class CiteseerGraphDataset(CitationGraphDataset):
>>> # Train, Validation and Test >>> # Train, Validation and Test
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=True): def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
name = 'citeseer' name = 'citeseer'
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose) super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -626,6 +643,8 @@ class PubmedGraphDataset(CitationGraphDataset): ...@@ -626,6 +643,8 @@ class PubmedGraphDataset(CitationGraphDataset):
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose: bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Attributes Attributes
---------- ----------
...@@ -668,10 +687,10 @@ class PubmedGraphDataset(CitationGraphDataset): ...@@ -668,10 +687,10 @@ class PubmedGraphDataset(CitationGraphDataset):
>>> # Train, Validation and Test >>> # Train, Validation and Test
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=True): def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
name = 'pubmed' name = 'pubmed'
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose) super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -699,7 +718,7 @@ class PubmedGraphDataset(CitationGraphDataset): ...@@ -699,7 +718,7 @@ class PubmedGraphDataset(CitationGraphDataset):
r"""The number of graphs in the dataset.""" r"""The number of graphs in the dataset."""
return super(PubmedGraphDataset, self).__len__() return super(PubmedGraphDataset, self).__len__()
def load_cora(raw_dir=None, force_reload=False, verbose=True): def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
"""Get CoraGraphDataset """Get CoraGraphDataset
Parameters Parameters
...@@ -711,15 +730,17 @@ def load_cora(raw_dir=None, force_reload=False, verbose=True): ...@@ -711,15 +730,17 @@ def load_cora(raw_dir=None, force_reload=False, verbose=True):
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose: bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Return Return
------- -------
CoraGraphDataset CoraGraphDataset
""" """
data = CoraGraphDataset(raw_dir, force_reload, verbose) data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
return data return data
def load_citeseer(raw_dir=None, force_reload=False, verbose=True): def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
"""Get CiteseerGraphDataset """Get CiteseerGraphDataset
Parameters Parameters
...@@ -731,15 +752,17 @@ def load_citeseer(raw_dir=None, force_reload=False, verbose=True): ...@@ -731,15 +752,17 @@ def load_citeseer(raw_dir=None, force_reload=False, verbose=True):
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose: bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Return Return
------- -------
CiteseerGraphDataset CiteseerGraphDataset
""" """
data = CiteseerGraphDataset(raw_dir, force_reload, verbose) data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
return data return data
def load_pubmed(raw_dir=None, force_reload=False, verbose=True): def load_pubmed(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
"""Get PubmedGraphDataset """Get PubmedGraphDataset
Parameters Parameters
...@@ -751,12 +774,14 @@ def load_pubmed(raw_dir=None, force_reload=False, verbose=True): ...@@ -751,12 +774,14 @@ def load_pubmed(raw_dir=None, force_reload=False, verbose=True):
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose: bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool
Whether to add reverse edges in graph. Default: True.
Return Return
------- -------
PubmedGraphDataset PubmedGraphDataset
""" """
data = PubmedGraphDataset(raw_dir, force_reload, verbose) data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
return data return data
class CoraBinary(DGLBuiltinDataset): class CoraBinary(DGLBuiltinDataset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment