Unverified Commit b347590a authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Dataset] Citation graph (#1902)



* citation graph

* GCN example use new citatoin dataset

* mxnet gat

* triger

* Fix

* Fix gat

* fix

* Fix tensorflow dgi

* Fix appnp, graphsage for mxnet

* fix monet and sgc for mxnet

* Fix tagcn

* update sgc, appnp
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 4be4b134
...@@ -17,6 +17,7 @@ from .bitcoinotc import BitcoinOTC ...@@ -17,6 +17,7 @@ from .bitcoinotc import BitcoinOTC
from .gdelt import GDELT from .gdelt import GDELT
from .icews18 import ICEWS18 from .icews18 import ICEWS18
from .qm7b import QM7b from .qm7b import QM7b
from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
def register_data_args(parser): def register_data_args(parser):
...@@ -27,7 +28,6 @@ def register_data_args(parser): ...@@ -27,7 +28,6 @@ def register_data_args(parser):
help= help=
"The input dataset. Can be cora, citeseer, pubmed, syn(synthetic dataset) or reddit" "The input dataset. Can be cora, citeseer, pubmed, syn(synthetic dataset) or reddit"
) )
citegrh.register_args(parser)
def load_data(args): def load_data(args):
...@@ -37,8 +37,6 @@ def load_data(args): ...@@ -37,8 +37,6 @@ def load_data(args):
return citegrh.load_citeseer() return citegrh.load_citeseer()
elif args.dataset == 'pubmed': elif args.dataset == 'pubmed':
return citegrh.load_pubmed() return citegrh.load_pubmed()
elif args.dataset == 'syn':
return citegrh.load_synthetic(args)
elif args.dataset is not None and args.dataset.startswith('reddit'): elif args.dataset is not None and args.dataset.startswith('reddit'):
return RedditDataset(self_loop=('self-loop' in args.dataset)) return RedditDataset(self_loop=('self-loop' in args.dataset))
else: else:
......
This diff is collapsed.
...@@ -249,6 +249,7 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -249,6 +249,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
# save for compatability # save for compatability
self._train_idx = F.tensor(train_idx) self._train_idx = F.tensor(train_idx)
self._test_idx = F.tensor(test_idx) self._test_idx = F.tensor(test_idx)
self._labels = labels
def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes): def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):
"""Build the graphs """Build the graphs
...@@ -638,17 +639,17 @@ class AIFBDataset(RDFGraphDataset): ...@@ -638,17 +639,17 @@ class AIFBDataset(RDFGraphDataset):
Return Return
------- -------
dgl.DGLGraph dgl.DGLGraph
graph structure, node features and labels. graph structure, node features and labels.
- ndata['train_mask']: mask for training node set - ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set - ndata['test_mask']: mask for testing node set
- ndata['labels']: mask for labels - ndata['labels']: mask for labels
""" """
return super(AIFBDataset, self).__getitem__(idx) return super(AIFBDataset, self).__getitem__(idx)
def __len__(self): def __len__(self):
r"""The number of graphs in the dataset.""" r"""The number of graphs in the dataset."""
return super(AIFBDataset, self).__len__(idx) return super(AIFBDataset, self).__len__()
def parse_entity(self, term): def parse_entity(self, term):
if isinstance(term, rdf.Literal): if isinstance(term, rdf.Literal):
...@@ -801,17 +802,17 @@ class MUTAGDataset(RDFGraphDataset): ...@@ -801,17 +802,17 @@ class MUTAGDataset(RDFGraphDataset):
Return Return
------- -------
dgl.DGLGraph dgl.DGLGraph
graph structure, node features and labels. graph structure, node features and labels.
- ndata['train_mask']: mask for training node set - ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set - ndata['test_mask']: mask for testing node set
- ndata['labels']: mask for labels - ndata['labels']: mask for labels
""" """
return super(MUTAGDataset, self).__getitem__(idx) return super(MUTAGDataset, self).__getitem__(idx)
def __len__(self): def __len__(self):
r"""The number of graphs in the dataset.""" r"""The number of graphs in the dataset."""
return super(MUTAGDataset, self).__len__(idx) return super(MUTAGDataset, self).__len__()
def parse_entity(self, term): def parse_entity(self, term):
if isinstance(term, rdf.Literal): if isinstance(term, rdf.Literal):
...@@ -980,17 +981,17 @@ class BGSDataset(RDFGraphDataset): ...@@ -980,17 +981,17 @@ class BGSDataset(RDFGraphDataset):
Return Return
------- -------
dgl.DGLGraph dgl.DGLGraph
graph structure, node features and labels. graph structure, node features and labels.
- ndata['train_mask']: mask for training node set - ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set - ndata['test_mask']: mask for testing node set
- ndata['labels']: mask for labels - ndata['labels']: mask for labels
""" """
return super(BGSDataset, self).__getitem__(idx) return super(BGSDataset, self).__getitem__(idx)
def __len__(self): def __len__(self):
r"""The number of graphs in the dataset.""" r"""The number of graphs in the dataset."""
return super(BGSDataset, self).__len__(idx) return super(BGSDataset, self).__len__()
def parse_entity(self, term): def parse_entity(self, term):
if isinstance(term, rdf.Literal): if isinstance(term, rdf.Literal):
...@@ -1155,17 +1156,17 @@ class AMDataset(RDFGraphDataset): ...@@ -1155,17 +1156,17 @@ class AMDataset(RDFGraphDataset):
Return Return
------- -------
dgl.DGLGraph dgl.DGLGraph
graph structure, node features and labels. graph structure, node features and labels.
- ndata['train_mask']: mask for training node set - ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set - ndata['test_mask']: mask for testing node set
- ndata['labels']: mask for labels - ndata['labels']: mask for labels
""" """
return super(AMDataset, self).__getitem__(idx) return super(AMDataset, self).__getitem__(idx)
def __len__(self): def __len__(self):
r"""The number of graphs in the dataset.""" r"""The number of graphs in the dataset."""
return super(AMDataset, self).__len__(idx) return super(AMDataset, self).__len__()
def parse_entity(self, term): def parse_entity(self, term):
if isinstance(term, rdf.Literal): if isinstance(term, rdf.Literal):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment