Unverified Commit 4d3c01d6 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Bug Fix] Fix the case when reverse_edge is False for citation graphs (#3840)



* Update citation_graph.py

* Update

* Update

* Update
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 71157b05
......@@ -122,8 +122,12 @@ class CitationGraphDataset(DGLBuiltinDataset):
if self.reverse_edge:
graph = nx.DiGraph(nx.from_dict_of_lists(graph))
g = from_networkx(graph)
else:
graph = nx.Graph(nx.from_dict_of_lists(graph))
edges = list(graph.edges())
u, v = map(list, zip(*edges))
g = dgl_graph((u, v))
onehot_labels = np.vstack((ally, ty))
onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :]
......@@ -137,9 +141,6 @@ class CitationGraphDataset(DGLBuiltinDataset):
val_mask = generate_mask_tensor(_sample_mask(idx_val, labels.shape[0]))
test_mask = generate_mask_tensor(_sample_mask(idx_test, labels.shape[0]))
self._graph = graph
g = from_networkx(graph)
g.ndata['train_mask'] = train_mask
g.ndata['val_mask'] = val_mask
g.ndata['test_mask'] = test_mask
......@@ -204,7 +205,6 @@ class CitationGraphDataset(DGLBuiltinDataset):
graph.ndata.pop('feat')
graph.ndata.pop('label')
graph = to_networkx(graph)
self._graph = nx.DiGraph(graph)
self._num_classes = info['num_classes']
self._g.ndata['train_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['train_mask']))
......@@ -250,10 +250,6 @@ class CitationGraphDataset(DGLBuiltinDataset):
""" Citation graph is used in many examples
We preserve these properties for compatability.
"""
@property
def graph(self):
deprecate_property('dataset.graph', 'dataset[0]')
return self._graph
@property
def train_mask(self):
......
......@@ -96,7 +96,7 @@ from dgl.data import citation_graph as citegrh
data = citegrh.load_cora()
G = dgl.DGLGraph(data.graph)
G = data[0]
labels = th.tensor(data.labels)
# find all the nodes labeled with class 0
......
......@@ -306,7 +306,7 @@ def load_cora_data():
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.BoolTensor(data.train_mask)
g = DGLGraph(data.graph)
g = data[0]
return g, features, labels, mask
##############################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment