Unverified Commit 1f4c0b71 authored by milesial's avatar milesial Committed by GitHub
Browse files

Fix processing of QM9EdgeDataset (#2801)

parent e36c5db6
...@@ -135,7 +135,7 @@ class QM9EdgeDataset(DGLDataset): ...@@ -135,7 +135,7 @@ class QM9EdgeDataset(DGLDataset):
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=True): verbose=True):
if label_keys == None: if label_keys is None:
self.label_keys = None self.label_keys = None
self.num_labels = 19 self.num_labels = 19
else: else:
...@@ -157,23 +157,7 @@ class QM9EdgeDataset(DGLDataset): ...@@ -157,23 +157,7 @@ class QM9EdgeDataset(DGLDataset):
download(self._url, path=file_path) download(self._url, path=file_path)
def process(self): def process(self):
self.load()
npz_path = f'{self.raw_dir}/qm9_edge.npz'
data_dict = np.load(npz_path, allow_pickle=True)
self.n_node = data_dict['n_node']
self.n_edge = data_dict['n_edge']
self.node_attr = data_dict['node_attr']
self.node_pos = data_dict['node_pos']
self.edge_attr = data_dict['edge_attr']
self.target = data_dict['target']
self.src = data_dict['src']
self.dst = data_dict['dst']
self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)])
self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)])
def has_cache(self): def has_cache(self):
npz_path = f'{self.raw_dir}/qm9_edge.npz' npz_path = f'{self.raw_dir}/qm9_edge.npz'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment