"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "e3d1d746f2f1c741ef4be9f8cf9cc55e8276806c"
Unverified Commit 9711d75f authored by Songqing Zhang's avatar Songqing Zhang Committed by GitHub
Browse files

[Misc] Simplify the data path's assignment in python utils (#6058)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 661f8177
......@@ -130,11 +130,19 @@ class QM9(QM9Dataset):
verbose=verbose,
)
@property
def graph_path(self):
return f"{self.save_path}/dgl_graph.bin"
@property
def line_graph_path(self):
return f"{self.save_path}/dgl_line_graph.bin"
def has_cache(self):
"""step 1, if True, goto step 5; else goto download(step 2), then step 3"""
graph_path = f"{self.save_path}/dgl_graph.bin"
line_graph_path = f"{self.save_path}/dgl_line_graph.bin"
return os.path.exists(graph_path) and os.path.exists(line_graph_path)
return os.path.exists(self.graph_path) and os.path.exists(
self.line_graph_path
)
def process(self):
"""step 3"""
......@@ -197,17 +205,13 @@ class QM9(QM9Dataset):
def save(self):
"""step 4"""
graph_path = f"{self.save_path}/dgl_graph.bin"
line_graph_path = f"{self.save_path}/dgl_line_graph.bin"
save_graphs(str(graph_path), self.graphs, self.label_dict)
save_graphs(str(line_graph_path), self.line_graphs)
save_graphs(str(self.graph_path), self.graphs, self.label_dict)
save_graphs(str(self.line_graph_path), self.line_graphs)
def load(self):
"""step 5"""
graph_path = f"{self.save_path}/dgl_graph.bin"
line_graph_path = f"{self.save_path}/dgl_line_graph.bin"
self.graphs, label_dict = load_graphs(graph_path)
self.line_graphs, _ = load_graphs(line_graph_path)
self.graphs, label_dict = load_graphs(self.graph_path)
self.line_graphs, _ = load_graphs(self.line_graph_path)
self.label = torch.stack(
[label_dict[key] for key in self.label_keys], dim=1
)
......
......@@ -68,21 +68,21 @@ class GASDataset(DGLBuiltinDataset):
self.graph = hg
@property
def graph_path(self):
return os.path.join(self.save_path, self.name + "_dgl_graph.bin")
def save(self):
"""save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
save_graphs(str(graph_path), self.graph)
save_graphs(str(self.graph_path), self.graph)
def has_cache(self):
"""check whether there are processed data in `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
return os.path.exists(graph_path)
return os.path.exists(self.graph_path)
def load(self):
"""load processed data from directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
graph, _ = load_graphs(str(graph_path))
graph, _ = load_graphs(str(self.graph_path))
self.graph = graph[0]
@property
......
......@@ -119,17 +119,18 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
)
self._graphs.append(g)
@property
def graph_path(self):
return os.path.join(self.save_path, "dgl_graph.bin")
def has_cache(self):
graph_path = os.path.join(self.save_path, "dgl_graph.bin")
return os.path.exists(graph_path)
return os.path.exists(self.graph_path)
def save(self):
graph_path = os.path.join(self.save_path, "dgl_graph.bin")
save_graphs(graph_path, self.graphs)
save_graphs(self.graph_path, self.graphs)
def load(self):
graph_path = os.path.join(self.save_path, "dgl_graph.bin")
self._graphs = load_graphs(graph_path)[0]
self._graphs = load_graphs(self.graph_path)[0]
@property
def graphs(self):
......
......@@ -211,27 +211,29 @@ class CitationGraphDataset(DGLBuiltinDataset):
)
)
@property
def graph_path(self):
return os.path.join(self.save_path, self.save_name + ".bin")
@property
def info_path(self):
return os.path.join(self.save_path, self.save_name + ".pkl")
def has_cache(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
if os.path.exists(graph_path) and os.path.exists(info_path):
if os.path.exists(self.graph_path) and os.path.exists(self.info_path):
return True
return False
def save(self):
"""save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
save_graphs(str(graph_path), self._g)
save_info(str(info_path), {"num_classes": self.num_classes})
save_graphs(str(self.graph_path), self._g)
save_info(str(self.info_path), {"num_classes": self.num_classes})
def load(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
graphs, _ = load_graphs(str(graph_path))
graphs, _ = load_graphs(str(self.graph_path))
info = load_info(str(info_path))
info = load_info(str(self.info_path))
graph = graphs[0]
self._g = graph
# for compatability
......@@ -854,26 +856,27 @@ class CoraBinary(DGLBuiltinDataset):
assert len(self.graphs) == len(self.pmpds)
assert len(self.graphs) == len(self.labels)
@property
def graph_path(self):
return os.path.join(self.save_path, self.save_name + ".bin")
def has_cache(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
if os.path.exists(graph_path):
if os.path.exists(self.graph_path):
return True
return False
def save(self):
"""save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
labels = {}
for i, label in enumerate(self.labels):
labels["{}".format(i)] = F.tensor(label)
save_graphs(str(graph_path), self.graphs, labels)
save_graphs(str(self.graph_path), self.graphs, labels)
if self.verbose:
print("Done saving data into cached files.")
def load(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
self.graphs, labels = load_graphs(str(graph_path))
self.graphs, labels = load_graphs(str(self.graph_path))
self.labels = []
for i in range(len(labels)):
......
......@@ -179,11 +179,9 @@ class FakeNewsDataset(DGLBuiltinDataset):
def save(self):
"""save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
info_path = os.path.join(self.save_path, self.name + "_dgl_graph.pkl")
save_graphs(str(graph_path), self.graphs)
save_graphs(str(self.graph_path), self.graphs)
save_info(
info_path,
self.info_path,
{
"label": self.labels,
"feature": self.feature,
......@@ -193,19 +191,24 @@ class FakeNewsDataset(DGLBuiltinDataset):
},
)
@property
def graph_path(self):
return os.path.join(self.save_path, self.name + "_dgl_graph.bin")
@property
def info_path(self):
return os.path.join(self.save_path, self.name + "_dgl_graph.pkl")
def has_cache(self):
"""check whether there are processed data in `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
info_path = os.path.join(self.save_path, self.name + "_dgl_graph.pkl")
return os.path.exists(graph_path) and os.path.exists(info_path)
return os.path.exists(self.graph_path) and os.path.exists(
self.info_path
)
def load(self):
"""load processed data from directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
info_path = os.path.join(self.save_path, self.name + "_dgl_graph.pkl")
graphs, _ = load_graphs(str(graph_path))
info = load_info(str(info_path))
graphs, _ = load_graphs(str(self.graph_path))
info = load_info(str(self.info_path))
self.graphs = graphs
self.labels = info["label"]
self.feature = info["feature"]
......
......@@ -103,14 +103,16 @@ class GDELTDataset(DGLBuiltinDataset):
self._start_time = self.time_index.min()
self._end_time = self.time_index.max()
@property
def info_path(self):
return os.path.join(self.save_path, self.mode + "_info.pkl")
def has_cache(self):
info_path = os.path.join(self.save_path, self.mode + "_info.pkl")
return os.path.exists(info_path)
return os.path.exists(self.info_path)
def save(self):
info_path = os.path.join(self.save_path, self.mode + "_info.pkl")
save_info(
info_path,
self.info_path,
{
"data": self.data,
"time_index": self.time_index,
......@@ -120,8 +122,7 @@ class GDELTDataset(DGLBuiltinDataset):
)
def load(self):
info_path = os.path.join(self.save_path, self.mode + "_info.pkl")
info = load_info(info_path)
info = load_info(self.info_path)
self.data, self.time_index, self._start_time, self._end_time = (
info["data"],
info["time_index"],
......
......@@ -358,12 +358,6 @@ class GINDataset(DGLBuiltinDataset):
)
def save(self):
graph_path = os.path.join(
self.save_path, "gin_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
)
label_dict = {"labels": self.labels}
info_dict = {
"N": self.N,
......@@ -380,18 +374,12 @@ class GINDataset(DGLBuiltinDataset):
"elabel_dict": self.elabel_dict,
"ndegree_dict": self.ndegree_dict,
}
save_graphs(str(graph_path), self.graphs, label_dict)
save_info(str(info_path), info_dict)
save_graphs(str(self.graph_path), self.graphs, label_dict)
save_info(str(self.info_path), info_dict)
def load(self):
graph_path = os.path.join(
self.save_path, "gin_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
)
graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path))
graphs, label_dict = load_graphs(str(self.graph_path))
info_dict = load_info(str(self.info_path))
self.graphs = graphs
self.labels = label_dict["labels"]
......@@ -410,14 +398,20 @@ class GINDataset(DGLBuiltinDataset):
self.ndegree_dict = info_dict["ndegree_dict"]
self.degree_as_nlabel = info_dict["degree_as_nlabel"]
def has_cache(self):
graph_path = os.path.join(
@property
def graph_path(self):
return os.path.join(
self.save_path, "gin_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
@property
def info_path(self):
return os.path.join(
self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
)
if os.path.exists(graph_path) and os.path.exists(info_path):
def has_cache(self):
if os.path.exists(self.graph_path) and os.path.exists(self.info_path):
return True
return False
......
......@@ -145,10 +145,16 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
g.ndata["ntype"] = ntype
self._g = g
@property
def graph_path(self):
return os.path.join(self.save_path, self.save_name + ".bin")
@property
def info_path(self):
return os.path.join(self.save_path, self.save_name + ".pkl")
def has_cache(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
if os.path.exists(graph_path) and os.path.exists(info_path):
if os.path.exists(self.graph_path) and os.path.exists(self.info_path):
return True
return False
......@@ -165,20 +171,16 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
def save(self):
"""save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
save_graphs(str(graph_path), self._g)
save_graphs(str(self.graph_path), self._g)
save_info(
str(info_path),
str(self.info_path),
{"num_nodes": self.num_nodes, "num_rels": self.num_rels},
)
def load(self):
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
graphs, _ = load_graphs(str(graph_path))
graphs, _ = load_graphs(str(self.graph_path))
info = load_info(str(info_path))
info = load_info(str(self.info_path))
self._num_nodes = info["num_nodes"]
self._num_rels = info["num_rels"]
self._g = graphs[0]
......
......@@ -86,17 +86,17 @@ class PATTERNDataset(DGLBuiltinDataset):
def process(self):
self.load()
def has_cache(self):
graph_path = os.path.join(
@property
def graph_path(self):
return os.path.join(
self.save_path, "SBM_PATTERN_{}.bin".format(self.mode)
)
return os.path.exists(graph_path)
def has_cache(self):
return os.path.exists(self.graph_path)
def load(self):
graph_path = os.path.join(
self.save_path, "SBM_PATTERN_{}.bin".format(self.mode)
)
self._graphs, _ = load_graphs(graph_path)
self._graphs, _ = load_graphs(self.graph_path)
@property
def num_classes(self):
......
......@@ -131,50 +131,41 @@ class PPIDataset(DGLBuiltinDataset):
)
self.graphs.append(g)
def has_cache(self):
graph_list_path = os.path.join(
@property
def graph_list_path(self):
return os.path.join(
self.save_path, "{}_dgl_graph_list.bin".format(self.mode)
)
g_path = os.path.join(
@property
def g_path(self):
return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode)
)
@property
def info_path(self):
return os.path.join(self.save_path, "{}_info.pkl".format(self.mode))
def has_cache(self):
return (
os.path.exists(graph_list_path)
and os.path.exists(g_path)
and os.path.exists(info_path)
os.path.exists(self.graph_list_path)
and os.path.exists(self.g_path)
and os.path.exists(self.info_path)
)
def save(self):
graph_list_path = os.path.join(
self.save_path, "{}_dgl_graph_list.bin".format(self.mode)
)
g_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
save_graphs(self.graph_list_path, self.graphs)
save_graphs(self.g_path, self.graph)
save_info(
self.info_path, {"labels": self._labels, "feats": self._feats}
)
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode)
)
save_graphs(graph_list_path, self.graphs)
save_graphs(g_path, self.graph)
save_info(info_path, {"labels": self._labels, "feats": self._feats})
def load(self):
graph_list_path = os.path.join(
self.save_path, "{}_dgl_graph_list.bin".format(self.mode)
)
g_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode)
)
self.graphs = load_graphs(graph_list_path)[0]
g, _ = load_graphs(g_path)
self.graphs = load_graphs(self.graph_list_path)[0]
g, _ = load_graphs(self.g_path)
self.graph = g[0]
info = load_info(info_path)
info = load_info(self.info_path)
self._labels = info["labels"]
self._feats = info["feats"]
......
......@@ -184,20 +184,22 @@ class QM9EdgeDataset(DGLDataset):
)
def download(self):
file_path = f"{self.raw_dir}/qm9_edge.npz"
if not os.path.exists(file_path):
download(self._url, path=file_path)
if not os.path.exists(self.npz_path):
download(self._url, path=self.npz_path)
def process(self):
self.load()
@property
def npz_path(self):
return f"{self.raw_dir}/qm9_edge.npz"
def has_cache(self):
npz_path = f"{self.raw_dir}/qm9_edge.npz"
return os.path.exists(npz_path)
return os.path.exists(self.npz_path)
def save(self):
np.savez_compressed(
f"{self.raw_dir}/qm9_edge.npz",
self.npz_path,
n_node=self.n_node,
n_edge=self.n_edge,
node_attr=self.node_attr,
......@@ -209,8 +211,7 @@ class QM9EdgeDataset(DGLDataset):
)
def load(self):
npz_path = f"{self.raw_dir}/qm9_edge.npz"
data_dict = np.load(npz_path, allow_pickle=True)
data_dict = np.load(self.npz_path, allow_pickle=True)
self.n_node = data_dict["n_node"]
self.n_edge = data_dict["n_edge"]
......
......@@ -148,36 +148,32 @@ class SBMMixtureDataset(DGLDataset):
self._line_graph_degrees = [in_degrees(lg) for lg in self._line_graphs]
self._pm_pds = list(zip(*[g.edges() for g in self._graphs]))[0]
def has_cache(self):
graph_path = os.path.join(
self.save_path, "graphs_{}.bin".format(self.hash)
)
line_graph_path = os.path.join(
@property
def graph_path(self):
return os.path.join(self.save_path, "graphs_{}.bin".format(self.hash))
@property
def line_graph_path(self):
return os.path.join(
self.save_path, "line_graphs_{}.bin".format(self.hash)
)
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash)
)
@property
def info_path(self):
return os.path.join(self.save_path, "info_{}.pkl".format(self.hash))
def has_cache(self):
return (
os.path.exists(graph_path)
and os.path.exists(line_graph_path)
and os.path.exists(info_path)
os.path.exists(self.graph_path)
and os.path.exists(self.line_graph_path)
and os.path.exists(self.info_path)
)
def save(self):
graph_path = os.path.join(
self.save_path, "graphs_{}.bin".format(self.hash)
)
line_graph_path = os.path.join(
self.save_path, "line_graphs_{}.bin".format(self.hash)
)
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash)
)
save_graphs(graph_path, self._graphs)
save_graphs(line_graph_path, self._line_graphs)
save_graphs(self.graph_path, self._graphs)
save_graphs(self.line_graph_path, self._line_graphs)
save_info(
info_path,
self.info_path,
{
"graph_degree": self._graph_degrees,
"line_graph_degree": self._line_graph_degrees,
......@@ -186,18 +182,9 @@ class SBMMixtureDataset(DGLDataset):
)
def load(self):
graph_path = os.path.join(
self.save_path, "graphs_{}.bin".format(self.hash)
)
line_graph_path = os.path.join(
self.save_path, "line_graphs_{}.bin".format(self.hash)
)
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash)
)
self._graphs, _ = load_graphs(graph_path)
self._line_graphs, _ = load_graphs(line_graph_path)
info = load_info(info_path)
self._graphs, _ = load_graphs(self.graph_path)
self._line_graphs, _ = load_graphs(self.line_graph_path)
info = load_info(self.info_path)
self._graph_degrees = info["graph_degree"]
self._line_graph_degrees = info["line_graph_degree"]
self._pm_pds = info["pm_pds"]
......
......@@ -221,27 +221,31 @@ class SSTDataset(DGLBuiltinDataset):
ret = from_networkx(g, node_attrs=["x", "y", "mask"])
return ret
@property
def graph_path(self):
return os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
@property
def vocab_path(self):
return os.path.join(self.save_path, "vocab.pkl")
def has_cache(self):
graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
vocab_path = os.path.join(self.save_path, "vocab.pkl")
return os.path.exists(graph_path) and os.path.exists(vocab_path)
return os.path.exists(self.graph_path) and os.path.exists(
self.vocab_path
)
def save(self):
graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
save_graphs(graph_path, self._trees)
vocab_path = os.path.join(self.save_path, "vocab.pkl")
save_info(vocab_path, {"vocab": self.vocab})
save_graphs(self.graph_path, self._trees)
save_info(self.vocab_path, {"vocab": self.vocab})
if self.pretrained_emb:
emb_path = os.path.join(self.save_path, "emb.pkl")
save_info(emb_path, {"embed": self.pretrained_emb})
def load(self):
graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
vocab_path = os.path.join(self.save_path, "vocab.pkl")
emb_path = os.path.join(self.save_path, "emb.pkl")
self._trees = load_graphs(graph_path)[0]
self._vocab = load_info(vocab_path)["vocab"]
self._trees = load_graphs(self.graph_path)[0]
self._vocab = load_info(self.vocab_path)["vocab"]
self._pretrained_emb = None
if os.path.exists(emb_path):
self._pretrained_emb = load_info(emb_path)["embed"]
......
......@@ -214,43 +214,37 @@ class LegacyTUDataset(DGLBuiltinDataset):
self.graph_labels = F.tensor(self.graph_labels)
def save(self):
graph_path = os.path.join(
self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
)
label_dict = {"labels": self.graph_labels}
info_dict = {
"max_num_node": self.max_num_node,
"num_labels": self.num_labels,
}
save_graphs(str(graph_path), self.graph_lists, label_dict)
save_info(str(info_path), info_dict)
save_graphs(str(self.graph_path), self.graph_lists, label_dict)
save_info(str(self.info_path), info_dict)
def load(self):
graph_path = os.path.join(
self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
)
graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path))
graphs, label_dict = load_graphs(str(self.graph_path))
info_dict = load_info(str(self.info_path))
self.graph_lists = graphs
self.graph_labels = label_dict["labels"]
self.max_num_node = info_dict["max_num_node"]
self.num_labels = info_dict["num_labels"]
def has_cache(self):
graph_path = os.path.join(
@property
def graph_path(self):
return os.path.join(
self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
@property
def info_path(self):
return os.path.join(
self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
)
if os.path.exists(graph_path) and os.path.exists(info_path):
def has_cache(self):
if os.path.exists(self.graph_path) and os.path.exists(self.info_path):
return True
return False
......@@ -455,22 +449,26 @@ class TUDataset(DGLBuiltinDataset):
self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list]
@property
def graph_path(self):
return os.path.join(self.save_path, "tu_{}.bin".format(self.name))
@property
def info_path(self):
return os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
def save(self):
graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name))
info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
label_dict = {"labels": self.graph_labels}
info_dict = {
"max_num_node": self.max_num_node,
"num_labels": self.num_labels,
}
save_graphs(str(graph_path), self.graph_lists, label_dict)
save_info(str(info_path), info_dict)
save_graphs(str(self.graph_path), self.graph_lists, label_dict)
save_info(str(self.info_path), info_dict)
def load(self):
graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name))
info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path))
graphs, label_dict = load_graphs(str(self.graph_path))
info_dict = load_info(str(self.info_path))
self.graph_lists = graphs
self.graph_labels = label_dict["labels"]
......@@ -478,9 +476,7 @@ class TUDataset(DGLBuiltinDataset):
self.num_labels = info_dict["num_labels"]
def has_cache(self):
graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name))
info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
if os.path.exists(graph_path) and os.path.exists(info_path):
if os.path.exists(self.graph_path) and os.path.exists(self.info_path):
return True
return False
......
......@@ -90,17 +90,15 @@ class ZINCDataset(DGLBuiltinDataset):
def process(self):
self.load()
@property
def graph_path(self):
return os.path.join(self.save_path, "ZincDGL_{}.bin".format(self.mode))
def has_cache(self):
graph_path = os.path.join(
self.save_path, "ZincDGL_{}.bin".format(self.mode)
)
return os.path.exists(graph_path)
return os.path.exists(self.graph_path)
def load(self):
graph_path = os.path.join(
self.save_path, "ZincDGL_{}.bin".format(self.mode)
)
self._graphs, self._labels = load_graphs(graph_path)
self._graphs, self._labels = load_graphs(self.graph_path)
@property
def num_atom_types(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment