Unverified Commit cb2c4ec1 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Fix (#1909)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 56bbf9cb
...@@ -13,8 +13,6 @@ from .utils import generate_mask_tensor ...@@ -13,8 +13,6 @@ from .utils import generate_mask_tensor
from .utils import deprecate_property, deprecate_function from .utils import deprecate_property, deprecate_function
from ..utils import retry_method_with_fix from ..utils import retry_method_with_fix
from .. import backend as F from .. import backend as F
from ..graph import DGLGraph
from ..graph import batch as graph_batch
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
class KnowledgeGraphDataset(DGLBuiltinDataset): class KnowledgeGraphDataset(DGLBuiltinDataset):
...@@ -140,34 +138,34 @@ class KnowledgeGraphDataset(DGLBuiltinDataset): ...@@ -140,34 +138,34 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
self._num_nodes = info['num_nodes'] self._num_nodes = info['num_nodes']
self._num_rels = info['num_rels'] self._num_rels = info['num_rels']
self._g = graphs[0] self._g = graphs[0]
train_mask = self._g.edata['train_mask'].numpy() train_mask = self._g.edata['train_edge_mask'].numpy()
val_mask = self._g.edata['val_mask'].numpy() val_mask = self._g.edata['valid_edge_mask'].numpy()
test_mask = self._g.edata['test_mask'].numpy() test_mask = self._g.edata['test_edge_mask'].numpy()
# convert mask tensor into bool tensor if possible # convert mask tensor into bool tensor if possible
self._g.ndata['train_edge_mask'] = generate_mask_tensor(self._g.ndata['train_edge_mask'].numpy()) self._g.edata['train_edge_mask'] = generate_mask_tensor(self._g.edata['train_edge_mask'].numpy())
self._g.ndata['valid_edge_mask'] = generate_mask_tensor(self._g.ndata['valid_edge_mask'].numpy()) self._g.edata['valid_edge_mask'] = generate_mask_tensor(self._g.edata['valid_edge_mask'].numpy())
self._g.ndata['test_edge_mask'] = generate_mask_tensor(self._g.ndata['test_edge_mask'].numpy()) self._g.edata['test_edge_mask'] = generate_mask_tensor(self._g.edata['test_edge_mask'].numpy())
self._g.ndata['train_mask'] = generate_mask_tensor(train_mask) self._g.edata['train_mask'] = generate_mask_tensor(self._g.edata['train_mask'].numpy())
self._g.ndata['val_mask'] = generate_mask_tensor(val_mask) self._g.edata['val_mask'] = generate_mask_tensor(self._g.edata['val_mask'].numpy())
self._g.ndata['test_mask'] = generate_mask_tensor(test_mask) self._g.edata['test_mask'] = generate_mask_tensor(self._g.edata['test_mask'].numpy())
# for compatability (with 0.4.x) generate train_idx, valid_idx and test_idx # for compatability (with 0.4.x) generate train_idx, valid_idx and test_idx
etype = self.g.edata['etype'].numpy() etype = self._g.edata['etype'].numpy()
self._etype = etype self._etype = etype
u, v = self._g.all_edges(form='uv') u, v = self._g.all_edges(form='uv')
u = u.numpy() u = u.numpy()
v = v.numpy() v = v.numpy()
train_idx = np.nonzero(train_mask==1) train_idx = np.nonzero(train_mask==1)
self._train = np.column_stack((u[train_idx], etype[train_idx], v[train_idx])) self._train = np.column_stack((u[train_idx], etype[train_idx], v[train_idx]))
valid_idx = np.nonzero(valid_mask==1) valid_idx = np.nonzero(val_mask==1)
self._valid = np.column_stack((u[valid_idx], etype[valid_idx], v[valid_idx])) self._valid = np.column_stack((u[valid_idx], etype[valid_idx], v[valid_idx]))
test_idx = np.nonzero(test_mask==1) test_idx = np.nonzero(test_mask==1)
self._test = np.column_stack((u[test_idx], etype[test_idx], v[test_idx])) self._test = np.column_stack((u[test_idx], etype[test_idx], v[test_idx]))
if self.verbose: if self.verbose:
print("# entities: {}".format(num_nodes)) print("# entities: {}".format(self.num_nodes))
print("# relations: {}".format(num_rels)) print("# relations: {}".format(self.num_rels))
print("# training edges: {}".format(len(train_idx))) print("# training edges: {}".format(len(train_idx)))
print("# validation edges: {}".format(len(valid_idx))) print("# validation edges: {}".format(len(valid_idx)))
print("# testing edges: {}".format(len(test_idx))) print("# testing edges: {}".format(len(test_idx)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment