Unverified Commit 4d3c01d6 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Bug Fix] Fix the case when reverse_edge is False for citation graphs (#3840)



* Update citation_graph.py

* Update

* Update

* Update
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 71157b05
...@@ -122,8 +122,12 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -122,8 +122,12 @@ class CitationGraphDataset(DGLBuiltinDataset):
if self.reverse_edge: if self.reverse_edge:
graph = nx.DiGraph(nx.from_dict_of_lists(graph)) graph = nx.DiGraph(nx.from_dict_of_lists(graph))
g = from_networkx(graph)
else: else:
graph = nx.Graph(nx.from_dict_of_lists(graph)) graph = nx.Graph(nx.from_dict_of_lists(graph))
edges = list(graph.edges())
u, v = map(list, zip(*edges))
g = dgl_graph((u, v))
onehot_labels = np.vstack((ally, ty)) onehot_labels = np.vstack((ally, ty))
onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :] onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :]
...@@ -137,9 +141,6 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -137,9 +141,6 @@ class CitationGraphDataset(DGLBuiltinDataset):
val_mask = generate_mask_tensor(_sample_mask(idx_val, labels.shape[0])) val_mask = generate_mask_tensor(_sample_mask(idx_val, labels.shape[0]))
test_mask = generate_mask_tensor(_sample_mask(idx_test, labels.shape[0])) test_mask = generate_mask_tensor(_sample_mask(idx_test, labels.shape[0]))
self._graph = graph
g = from_networkx(graph)
g.ndata['train_mask'] = train_mask g.ndata['train_mask'] = train_mask
g.ndata['val_mask'] = val_mask g.ndata['val_mask'] = val_mask
g.ndata['test_mask'] = test_mask g.ndata['test_mask'] = test_mask
...@@ -204,7 +205,6 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -204,7 +205,6 @@ class CitationGraphDataset(DGLBuiltinDataset):
graph.ndata.pop('feat') graph.ndata.pop('feat')
graph.ndata.pop('label') graph.ndata.pop('label')
graph = to_networkx(graph) graph = to_networkx(graph)
self._graph = nx.DiGraph(graph)
self._num_classes = info['num_classes'] self._num_classes = info['num_classes']
self._g.ndata['train_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['train_mask'])) self._g.ndata['train_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['train_mask']))
...@@ -250,10 +250,6 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -250,10 +250,6 @@ class CitationGraphDataset(DGLBuiltinDataset):
""" Citation graph is used in many examples """ Citation graph is used in many examples
We preserve these properties for compatability. We preserve these properties for compatability.
""" """
@property
def graph(self):
deprecate_property('dataset.graph', 'dataset[0]')
return self._graph
@property @property
def train_mask(self): def train_mask(self):
......
...@@ -96,7 +96,7 @@ from dgl.data import citation_graph as citegrh ...@@ -96,7 +96,7 @@ from dgl.data import citation_graph as citegrh
data = citegrh.load_cora() data = citegrh.load_cora()
G = dgl.DGLGraph(data.graph) G = data[0]
labels = th.tensor(data.labels) labels = th.tensor(data.labels)
# find all the nodes labeled with class 0 # find all the nodes labeled with class 0
......
...@@ -306,7 +306,7 @@ def load_cora_data(): ...@@ -306,7 +306,7 @@ def load_cora_data():
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
mask = torch.BoolTensor(data.train_mask) mask = torch.BoolTensor(data.train_mask)
g = DGLGraph(data.graph) g = data[0]
return g, features, labels, mask return g, features, labels, mask
############################################################################## ##############################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment