Unverified Commit 9711d75f authored by Songqing Zhang's avatar Songqing Zhang Committed by GitHub
Browse files

[Misc] Simplify the data path's assignment in python utils (#6058)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 661f8177
...@@ -130,11 +130,19 @@ class QM9(QM9Dataset): ...@@ -130,11 +130,19 @@ class QM9(QM9Dataset):
verbose=verbose, verbose=verbose,
) )
@property
def graph_path(self):
return f"{self.save_path}/dgl_graph.bin"
@property
def line_graph_path(self):
return f"{self.save_path}/dgl_line_graph.bin"
def has_cache(self): def has_cache(self):
"""step 1, if True, goto step 5; else goto download(step 2), then step 3""" """step 1, if True, goto step 5; else goto download(step 2), then step 3"""
graph_path = f"{self.save_path}/dgl_graph.bin" return os.path.exists(self.graph_path) and os.path.exists(
line_graph_path = f"{self.save_path}/dgl_line_graph.bin" self.line_graph_path
return os.path.exists(graph_path) and os.path.exists(line_graph_path) )
def process(self): def process(self):
"""step 3""" """step 3"""
...@@ -197,17 +205,13 @@ class QM9(QM9Dataset): ...@@ -197,17 +205,13 @@ class QM9(QM9Dataset):
def save(self): def save(self):
"""step 4""" """step 4"""
graph_path = f"{self.save_path}/dgl_graph.bin" save_graphs(str(self.graph_path), self.graphs, self.label_dict)
line_graph_path = f"{self.save_path}/dgl_line_graph.bin" save_graphs(str(self.line_graph_path), self.line_graphs)
save_graphs(str(graph_path), self.graphs, self.label_dict)
save_graphs(str(line_graph_path), self.line_graphs)
def load(self): def load(self):
"""step 5""" """step 5"""
graph_path = f"{self.save_path}/dgl_graph.bin" self.graphs, label_dict = load_graphs(self.graph_path)
line_graph_path = f"{self.save_path}/dgl_line_graph.bin" self.line_graphs, _ = load_graphs(self.line_graph_path)
self.graphs, label_dict = load_graphs(graph_path)
self.line_graphs, _ = load_graphs(line_graph_path)
self.label = torch.stack( self.label = torch.stack(
[label_dict[key] for key in self.label_keys], dim=1 [label_dict[key] for key in self.label_keys], dim=1
) )
......
...@@ -68,21 +68,21 @@ class GASDataset(DGLBuiltinDataset): ...@@ -68,21 +68,21 @@ class GASDataset(DGLBuiltinDataset):
self.graph = hg self.graph = hg
@property
def graph_path(self):
return os.path.join(self.save_path, self.name + "_dgl_graph.bin")
def save(self): def save(self):
"""save the graph list and the labels""" """save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin") save_graphs(str(self.graph_path), self.graph)
save_graphs(str(graph_path), self.graph)
def has_cache(self): def has_cache(self):
"""check whether there are processed data in `self.save_path`""" """check whether there are processed data in `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin") return os.path.exists(self.graph_path)
return os.path.exists(graph_path)
def load(self): def load(self):
"""load processed data from directory `self.save_path`""" """load processed data from directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin") graph, _ = load_graphs(str(self.graph_path))
graph, _ = load_graphs(str(graph_path))
self.graph = graph[0] self.graph = graph[0]
@property @property
......
...@@ -119,17 +119,18 @@ class BitcoinOTCDataset(DGLBuiltinDataset): ...@@ -119,17 +119,18 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
) )
self._graphs.append(g) self._graphs.append(g)
@property
def graph_path(self):
return os.path.join(self.save_path, "dgl_graph.bin")
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, "dgl_graph.bin") return os.path.exists(self.graph_path)
return os.path.exists(graph_path)
def save(self): def save(self):
graph_path = os.path.join(self.save_path, "dgl_graph.bin") save_graphs(self.graph_path, self.graphs)
save_graphs(graph_path, self.graphs)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, "dgl_graph.bin") self._graphs = load_graphs(self.graph_path)[0]
self._graphs = load_graphs(graph_path)[0]
@property @property
def graphs(self): def graphs(self):
......
...@@ -211,27 +211,29 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -211,27 +211,29 @@ class CitationGraphDataset(DGLBuiltinDataset):
) )
) )
@property
def graph_path(self):
return os.path.join(self.save_path, self.save_name + ".bin")
@property
def info_path(self):
return os.path.join(self.save_path, self.save_name + ".pkl")
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin") if os.path.exists(self.graph_path) and os.path.exists(self.info_path):
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
if os.path.exists(graph_path) and os.path.exists(info_path):
return True return True
return False return False
def save(self): def save(self):
"""save the graph list and the labels""" """save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.save_name + ".bin") save_graphs(str(self.graph_path), self._g)
info_path = os.path.join(self.save_path, self.save_name + ".pkl") save_info(str(self.info_path), {"num_classes": self.num_classes})
save_graphs(str(graph_path), self._g)
save_info(str(info_path), {"num_classes": self.num_classes})
def load(self): def load(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin") graphs, _ = load_graphs(str(self.graph_path))
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
graphs, _ = load_graphs(str(graph_path))
info = load_info(str(info_path)) info = load_info(str(self.info_path))
graph = graphs[0] graph = graphs[0]
self._g = graph self._g = graph
# for compatability # for compatability
...@@ -854,26 +856,27 @@ class CoraBinary(DGLBuiltinDataset): ...@@ -854,26 +856,27 @@ class CoraBinary(DGLBuiltinDataset):
assert len(self.graphs) == len(self.pmpds) assert len(self.graphs) == len(self.pmpds)
assert len(self.graphs) == len(self.labels) assert len(self.graphs) == len(self.labels)
@property
def graph_path(self):
return os.path.join(self.save_path, self.save_name + ".bin")
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin") if os.path.exists(self.graph_path):
if os.path.exists(graph_path):
return True return True
return False return False
def save(self): def save(self):
"""save the graph list and the labels""" """save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
labels = {} labels = {}
for i, label in enumerate(self.labels): for i, label in enumerate(self.labels):
labels["{}".format(i)] = F.tensor(label) labels["{}".format(i)] = F.tensor(label)
save_graphs(str(graph_path), self.graphs, labels) save_graphs(str(self.graph_path), self.graphs, labels)
if self.verbose: if self.verbose:
print("Done saving data into cached files.") print("Done saving data into cached files.")
def load(self): def load(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin") self.graphs, labels = load_graphs(str(self.graph_path))
self.graphs, labels = load_graphs(str(graph_path))
self.labels = [] self.labels = []
for i in range(len(labels)): for i in range(len(labels)):
......
...@@ -179,11 +179,9 @@ class FakeNewsDataset(DGLBuiltinDataset): ...@@ -179,11 +179,9 @@ class FakeNewsDataset(DGLBuiltinDataset):
def save(self): def save(self):
"""save the graph list and the labels""" """save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin") save_graphs(str(self.graph_path), self.graphs)
info_path = os.path.join(self.save_path, self.name + "_dgl_graph.pkl")
save_graphs(str(graph_path), self.graphs)
save_info( save_info(
info_path, self.info_path,
{ {
"label": self.labels, "label": self.labels,
"feature": self.feature, "feature": self.feature,
...@@ -193,19 +191,24 @@ class FakeNewsDataset(DGLBuiltinDataset): ...@@ -193,19 +191,24 @@ class FakeNewsDataset(DGLBuiltinDataset):
}, },
) )
@property
def graph_path(self):
return os.path.join(self.save_path, self.name + "_dgl_graph.bin")
@property
def info_path(self):
return os.path.join(self.save_path, self.name + "_dgl_graph.pkl")
def has_cache(self): def has_cache(self):
"""check whether there are processed data in `self.save_path`""" """check whether there are processed data in `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin") return os.path.exists(self.graph_path) and os.path.exists(
info_path = os.path.join(self.save_path, self.name + "_dgl_graph.pkl") self.info_path
return os.path.exists(graph_path) and os.path.exists(info_path) )
def load(self): def load(self):
"""load processed data from directory `self.save_path`""" """load processed data from directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin") graphs, _ = load_graphs(str(self.graph_path))
info_path = os.path.join(self.save_path, self.name + "_dgl_graph.pkl") info = load_info(str(self.info_path))
graphs, _ = load_graphs(str(graph_path))
info = load_info(str(info_path))
self.graphs = graphs self.graphs = graphs
self.labels = info["label"] self.labels = info["label"]
self.feature = info["feature"] self.feature = info["feature"]
......
...@@ -103,14 +103,16 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -103,14 +103,16 @@ class GDELTDataset(DGLBuiltinDataset):
self._start_time = self.time_index.min() self._start_time = self.time_index.min()
self._end_time = self.time_index.max() self._end_time = self.time_index.max()
@property
def info_path(self):
return os.path.join(self.save_path, self.mode + "_info.pkl")
def has_cache(self): def has_cache(self):
info_path = os.path.join(self.save_path, self.mode + "_info.pkl") return os.path.exists(self.info_path)
return os.path.exists(info_path)
def save(self): def save(self):
info_path = os.path.join(self.save_path, self.mode + "_info.pkl")
save_info( save_info(
info_path, self.info_path,
{ {
"data": self.data, "data": self.data,
"time_index": self.time_index, "time_index": self.time_index,
...@@ -120,8 +122,7 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -120,8 +122,7 @@ class GDELTDataset(DGLBuiltinDataset):
) )
def load(self): def load(self):
info_path = os.path.join(self.save_path, self.mode + "_info.pkl") info = load_info(self.info_path)
info = load_info(info_path)
self.data, self.time_index, self._start_time, self._end_time = ( self.data, self.time_index, self._start_time, self._end_time = (
info["data"], info["data"],
info["time_index"], info["time_index"],
......
...@@ -358,12 +358,6 @@ class GINDataset(DGLBuiltinDataset): ...@@ -358,12 +358,6 @@ class GINDataset(DGLBuiltinDataset):
) )
def save(self): def save(self):
graph_path = os.path.join(
self.save_path, "gin_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
)
label_dict = {"labels": self.labels} label_dict = {"labels": self.labels}
info_dict = { info_dict = {
"N": self.N, "N": self.N,
...@@ -380,18 +374,12 @@ class GINDataset(DGLBuiltinDataset): ...@@ -380,18 +374,12 @@ class GINDataset(DGLBuiltinDataset):
"elabel_dict": self.elabel_dict, "elabel_dict": self.elabel_dict,
"ndegree_dict": self.ndegree_dict, "ndegree_dict": self.ndegree_dict,
} }
save_graphs(str(graph_path), self.graphs, label_dict) save_graphs(str(self.graph_path), self.graphs, label_dict)
save_info(str(info_path), info_dict) save_info(str(self.info_path), info_dict)
def load(self): def load(self):
graph_path = os.path.join( graphs, label_dict = load_graphs(str(self.graph_path))
self.save_path, "gin_{}_{}.bin".format(self.name, self.hash) info_dict = load_info(str(self.info_path))
)
info_path = os.path.join(
self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
)
graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path))
self.graphs = graphs self.graphs = graphs
self.labels = label_dict["labels"] self.labels = label_dict["labels"]
...@@ -410,14 +398,20 @@ class GINDataset(DGLBuiltinDataset): ...@@ -410,14 +398,20 @@ class GINDataset(DGLBuiltinDataset):
self.ndegree_dict = info_dict["ndegree_dict"] self.ndegree_dict = info_dict["ndegree_dict"]
self.degree_as_nlabel = info_dict["degree_as_nlabel"] self.degree_as_nlabel = info_dict["degree_as_nlabel"]
def has_cache(self): @property
graph_path = os.path.join( def graph_path(self):
return os.path.join(
self.save_path, "gin_{}_{}.bin".format(self.name, self.hash) self.save_path, "gin_{}_{}.bin".format(self.name, self.hash)
) )
info_path = os.path.join(
@property
def info_path(self):
return os.path.join(
self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash) self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
) )
if os.path.exists(graph_path) and os.path.exists(info_path):
def has_cache(self):
if os.path.exists(self.graph_path) and os.path.exists(self.info_path):
return True return True
return False return False
......
...@@ -145,10 +145,16 @@ class KnowledgeGraphDataset(DGLBuiltinDataset): ...@@ -145,10 +145,16 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
g.ndata["ntype"] = ntype g.ndata["ntype"] = ntype
self._g = g self._g = g
@property
def graph_path(self):
return os.path.join(self.save_path, self.save_name + ".bin")
@property
def info_path(self):
return os.path.join(self.save_path, self.save_name + ".pkl")
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin") if os.path.exists(self.graph_path) and os.path.exists(self.info_path):
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
if os.path.exists(graph_path) and os.path.exists(info_path):
return True return True
return False return False
...@@ -165,20 +171,16 @@ class KnowledgeGraphDataset(DGLBuiltinDataset): ...@@ -165,20 +171,16 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
def save(self): def save(self):
"""save the graph list and the labels""" """save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.save_name + ".bin") save_graphs(str(self.graph_path), self._g)
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
save_graphs(str(graph_path), self._g)
save_info( save_info(
str(info_path), str(self.info_path),
{"num_nodes": self.num_nodes, "num_rels": self.num_rels}, {"num_nodes": self.num_nodes, "num_rels": self.num_rels},
) )
def load(self): def load(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin") graphs, _ = load_graphs(str(self.graph_path))
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
graphs, _ = load_graphs(str(graph_path))
info = load_info(str(info_path)) info = load_info(str(self.info_path))
self._num_nodes = info["num_nodes"] self._num_nodes = info["num_nodes"]
self._num_rels = info["num_rels"] self._num_rels = info["num_rels"]
self._g = graphs[0] self._g = graphs[0]
......
...@@ -86,17 +86,17 @@ class PATTERNDataset(DGLBuiltinDataset): ...@@ -86,17 +86,17 @@ class PATTERNDataset(DGLBuiltinDataset):
def process(self): def process(self):
self.load() self.load()
def has_cache(self): @property
graph_path = os.path.join( def graph_path(self):
return os.path.join(
self.save_path, "SBM_PATTERN_{}.bin".format(self.mode) self.save_path, "SBM_PATTERN_{}.bin".format(self.mode)
) )
return os.path.exists(graph_path)
def has_cache(self):
return os.path.exists(self.graph_path)
def load(self): def load(self):
graph_path = os.path.join( self._graphs, _ = load_graphs(self.graph_path)
self.save_path, "SBM_PATTERN_{}.bin".format(self.mode)
)
self._graphs, _ = load_graphs(graph_path)
@property @property
def num_classes(self): def num_classes(self):
......
...@@ -131,50 +131,41 @@ class PPIDataset(DGLBuiltinDataset): ...@@ -131,50 +131,41 @@ class PPIDataset(DGLBuiltinDataset):
) )
self.graphs.append(g) self.graphs.append(g)
def has_cache(self): @property
graph_list_path = os.path.join( def graph_list_path(self):
return os.path.join(
self.save_path, "{}_dgl_graph_list.bin".format(self.mode) self.save_path, "{}_dgl_graph_list.bin".format(self.mode)
) )
g_path = os.path.join(
@property
def g_path(self):
return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode) self.save_path, "{}_dgl_graph.bin".format(self.mode)
) )
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode) @property
) def info_path(self):
return os.path.join(self.save_path, "{}_info.pkl".format(self.mode))
def has_cache(self):
return ( return (
os.path.exists(graph_list_path) os.path.exists(self.graph_list_path)
and os.path.exists(g_path) and os.path.exists(self.g_path)
and os.path.exists(info_path) and os.path.exists(self.info_path)
) )
def save(self): def save(self):
graph_list_path = os.path.join( save_graphs(self.graph_list_path, self.graphs)
self.save_path, "{}_dgl_graph_list.bin".format(self.mode) save_graphs(self.g_path, self.graph)
) save_info(
g_path = os.path.join( self.info_path, {"labels": self._labels, "feats": self._feats}
self.save_path, "{}_dgl_graph.bin".format(self.mode)
) )
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode)
)
save_graphs(graph_list_path, self.graphs)
save_graphs(g_path, self.graph)
save_info(info_path, {"labels": self._labels, "feats": self._feats})
def load(self): def load(self):
graph_list_path = os.path.join( self.graphs = load_graphs(self.graph_list_path)[0]
self.save_path, "{}_dgl_graph_list.bin".format(self.mode) g, _ = load_graphs(self.g_path)
)
g_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode)
)
self.graphs = load_graphs(graph_list_path)[0]
g, _ = load_graphs(g_path)
self.graph = g[0] self.graph = g[0]
info = load_info(info_path) info = load_info(self.info_path)
self._labels = info["labels"] self._labels = info["labels"]
self._feats = info["feats"] self._feats = info["feats"]
......
...@@ -184,20 +184,22 @@ class QM9EdgeDataset(DGLDataset): ...@@ -184,20 +184,22 @@ class QM9EdgeDataset(DGLDataset):
) )
def download(self): def download(self):
file_path = f"{self.raw_dir}/qm9_edge.npz" if not os.path.exists(self.npz_path):
if not os.path.exists(file_path): download(self._url, path=self.npz_path)
download(self._url, path=file_path)
def process(self): def process(self):
self.load() self.load()
@property
def npz_path(self):
return f"{self.raw_dir}/qm9_edge.npz"
def has_cache(self): def has_cache(self):
npz_path = f"{self.raw_dir}/qm9_edge.npz" return os.path.exists(self.npz_path)
return os.path.exists(npz_path)
def save(self): def save(self):
np.savez_compressed( np.savez_compressed(
f"{self.raw_dir}/qm9_edge.npz", self.npz_path,
n_node=self.n_node, n_node=self.n_node,
n_edge=self.n_edge, n_edge=self.n_edge,
node_attr=self.node_attr, node_attr=self.node_attr,
...@@ -209,8 +211,7 @@ class QM9EdgeDataset(DGLDataset): ...@@ -209,8 +211,7 @@ class QM9EdgeDataset(DGLDataset):
) )
def load(self): def load(self):
npz_path = f"{self.raw_dir}/qm9_edge.npz" data_dict = np.load(self.npz_path, allow_pickle=True)
data_dict = np.load(npz_path, allow_pickle=True)
self.n_node = data_dict["n_node"] self.n_node = data_dict["n_node"]
self.n_edge = data_dict["n_edge"] self.n_edge = data_dict["n_edge"]
......
...@@ -148,36 +148,32 @@ class SBMMixtureDataset(DGLDataset): ...@@ -148,36 +148,32 @@ class SBMMixtureDataset(DGLDataset):
self._line_graph_degrees = [in_degrees(lg) for lg in self._line_graphs] self._line_graph_degrees = [in_degrees(lg) for lg in self._line_graphs]
self._pm_pds = list(zip(*[g.edges() for g in self._graphs]))[0] self._pm_pds = list(zip(*[g.edges() for g in self._graphs]))[0]
def has_cache(self): @property
graph_path = os.path.join( def graph_path(self):
self.save_path, "graphs_{}.bin".format(self.hash) return os.path.join(self.save_path, "graphs_{}.bin".format(self.hash))
)
line_graph_path = os.path.join( @property
def line_graph_path(self):
return os.path.join(
self.save_path, "line_graphs_{}.bin".format(self.hash) self.save_path, "line_graphs_{}.bin".format(self.hash)
) )
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash) @property
) def info_path(self):
return os.path.join(self.save_path, "info_{}.pkl".format(self.hash))
def has_cache(self):
return ( return (
os.path.exists(graph_path) os.path.exists(self.graph_path)
and os.path.exists(line_graph_path) and os.path.exists(self.line_graph_path)
and os.path.exists(info_path) and os.path.exists(self.info_path)
) )
def save(self): def save(self):
graph_path = os.path.join( save_graphs(self.graph_path, self._graphs)
self.save_path, "graphs_{}.bin".format(self.hash) save_graphs(self.line_graph_path, self._line_graphs)
)
line_graph_path = os.path.join(
self.save_path, "line_graphs_{}.bin".format(self.hash)
)
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash)
)
save_graphs(graph_path, self._graphs)
save_graphs(line_graph_path, self._line_graphs)
save_info( save_info(
info_path, self.info_path,
{ {
"graph_degree": self._graph_degrees, "graph_degree": self._graph_degrees,
"line_graph_degree": self._line_graph_degrees, "line_graph_degree": self._line_graph_degrees,
...@@ -186,18 +182,9 @@ class SBMMixtureDataset(DGLDataset): ...@@ -186,18 +182,9 @@ class SBMMixtureDataset(DGLDataset):
) )
def load(self): def load(self):
graph_path = os.path.join( self._graphs, _ = load_graphs(self.graph_path)
self.save_path, "graphs_{}.bin".format(self.hash) self._line_graphs, _ = load_graphs(self.line_graph_path)
) info = load_info(self.info_path)
line_graph_path = os.path.join(
self.save_path, "line_graphs_{}.bin".format(self.hash)
)
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash)
)
self._graphs, _ = load_graphs(graph_path)
self._line_graphs, _ = load_graphs(line_graph_path)
info = load_info(info_path)
self._graph_degrees = info["graph_degree"] self._graph_degrees = info["graph_degree"]
self._line_graph_degrees = info["line_graph_degree"] self._line_graph_degrees = info["line_graph_degree"]
self._pm_pds = info["pm_pds"] self._pm_pds = info["pm_pds"]
......
...@@ -221,27 +221,31 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -221,27 +221,31 @@ class SSTDataset(DGLBuiltinDataset):
ret = from_networkx(g, node_attrs=["x", "y", "mask"]) ret = from_networkx(g, node_attrs=["x", "y", "mask"])
return ret return ret
@property
def graph_path(self):
return os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
@property
def vocab_path(self):
return os.path.join(self.save_path, "vocab.pkl")
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin") return os.path.exists(self.graph_path) and os.path.exists(
vocab_path = os.path.join(self.save_path, "vocab.pkl") self.vocab_path
return os.path.exists(graph_path) and os.path.exists(vocab_path) )
def save(self): def save(self):
graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin") save_graphs(self.graph_path, self._trees)
save_graphs(graph_path, self._trees) save_info(self.vocab_path, {"vocab": self.vocab})
vocab_path = os.path.join(self.save_path, "vocab.pkl")
save_info(vocab_path, {"vocab": self.vocab})
if self.pretrained_emb: if self.pretrained_emb:
emb_path = os.path.join(self.save_path, "emb.pkl") emb_path = os.path.join(self.save_path, "emb.pkl")
save_info(emb_path, {"embed": self.pretrained_emb}) save_info(emb_path, {"embed": self.pretrained_emb})
def load(self): def load(self):
graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
vocab_path = os.path.join(self.save_path, "vocab.pkl")
emb_path = os.path.join(self.save_path, "emb.pkl") emb_path = os.path.join(self.save_path, "emb.pkl")
self._trees = load_graphs(graph_path)[0] self._trees = load_graphs(self.graph_path)[0]
self._vocab = load_info(vocab_path)["vocab"] self._vocab = load_info(self.vocab_path)["vocab"]
self._pretrained_emb = None self._pretrained_emb = None
if os.path.exists(emb_path): if os.path.exists(emb_path):
self._pretrained_emb = load_info(emb_path)["embed"] self._pretrained_emb = load_info(emb_path)["embed"]
......
...@@ -214,43 +214,37 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -214,43 +214,37 @@ class LegacyTUDataset(DGLBuiltinDataset):
self.graph_labels = F.tensor(self.graph_labels) self.graph_labels = F.tensor(self.graph_labels)
def save(self): def save(self):
graph_path = os.path.join(
self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
)
label_dict = {"labels": self.graph_labels} label_dict = {"labels": self.graph_labels}
info_dict = { info_dict = {
"max_num_node": self.max_num_node, "max_num_node": self.max_num_node,
"num_labels": self.num_labels, "num_labels": self.num_labels,
} }
save_graphs(str(graph_path), self.graph_lists, label_dict) save_graphs(str(self.graph_path), self.graph_lists, label_dict)
save_info(str(info_path), info_dict) save_info(str(self.info_path), info_dict)
def load(self): def load(self):
graph_path = os.path.join( graphs, label_dict = load_graphs(str(self.graph_path))
self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash) info_dict = load_info(str(self.info_path))
)
info_path = os.path.join(
self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
)
graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path))
self.graph_lists = graphs self.graph_lists = graphs
self.graph_labels = label_dict["labels"] self.graph_labels = label_dict["labels"]
self.max_num_node = info_dict["max_num_node"] self.max_num_node = info_dict["max_num_node"]
self.num_labels = info_dict["num_labels"] self.num_labels = info_dict["num_labels"]
def has_cache(self): @property
graph_path = os.path.join( def graph_path(self):
return os.path.join(
self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash) self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash)
) )
info_path = os.path.join(
@property
def info_path(self):
return os.path.join(
self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash) self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
) )
if os.path.exists(graph_path) and os.path.exists(info_path):
def has_cache(self):
if os.path.exists(self.graph_path) and os.path.exists(self.info_path):
return True return True
return False return False
...@@ -455,22 +449,26 @@ class TUDataset(DGLBuiltinDataset): ...@@ -455,22 +449,26 @@ class TUDataset(DGLBuiltinDataset):
self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list] self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list]
@property
def graph_path(self):
return os.path.join(self.save_path, "tu_{}.bin".format(self.name))
@property
def info_path(self):
return os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
def save(self): def save(self):
graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name))
info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
label_dict = {"labels": self.graph_labels} label_dict = {"labels": self.graph_labels}
info_dict = { info_dict = {
"max_num_node": self.max_num_node, "max_num_node": self.max_num_node,
"num_labels": self.num_labels, "num_labels": self.num_labels,
} }
save_graphs(str(graph_path), self.graph_lists, label_dict) save_graphs(str(self.graph_path), self.graph_lists, label_dict)
save_info(str(info_path), info_dict) save_info(str(self.info_path), info_dict)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name)) graphs, label_dict = load_graphs(str(self.graph_path))
info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name)) info_dict = load_info(str(self.info_path))
graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path))
self.graph_lists = graphs self.graph_lists = graphs
self.graph_labels = label_dict["labels"] self.graph_labels = label_dict["labels"]
...@@ -478,9 +476,7 @@ class TUDataset(DGLBuiltinDataset): ...@@ -478,9 +476,7 @@ class TUDataset(DGLBuiltinDataset):
self.num_labels = info_dict["num_labels"] self.num_labels = info_dict["num_labels"]
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name)) if os.path.exists(self.graph_path) and os.path.exists(self.info_path):
info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
if os.path.exists(graph_path) and os.path.exists(info_path):
return True return True
return False return False
......
...@@ -90,17 +90,15 @@ class ZINCDataset(DGLBuiltinDataset): ...@@ -90,17 +90,15 @@ class ZINCDataset(DGLBuiltinDataset):
def process(self): def process(self):
self.load() self.load()
@property
def graph_path(self):
return os.path.join(self.save_path, "ZincDGL_{}.bin".format(self.mode))
def has_cache(self): def has_cache(self):
graph_path = os.path.join( return os.path.exists(self.graph_path)
self.save_path, "ZincDGL_{}.bin".format(self.mode)
)
return os.path.exists(graph_path)
def load(self): def load(self):
graph_path = os.path.join( self._graphs, self._labels = load_graphs(self.graph_path)
self.save_path, "ZincDGL_{}.bin".format(self.mode)
)
self._graphs, self._labels = load_graphs(graph_path)
@property @property
def num_atom_types(self): def num_atom_types(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment