"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "25a51b63ca75e1351069bee87a0fb3df5abb89c3"
Unverified Commit b3538802 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Bugfix] fix vocab not saved bug in SSTDataset (#2036)

* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* Update tree.py
parent 5cb57593
...@@ -195,27 +195,27 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -195,27 +195,27 @@ class SSTDataset(DGLBuiltinDataset):
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
ret = os.path.exists(graph_path) vocab_path = os.path.join(self.save_path, 'vocab.pkl')
if self.mode == 'train': return os.path.exists(graph_path) and os.path.exists(vocab_path)
info_path = os.path.join(self.save_path, 'graph_info.pkl')
ret = ret and os.path.exists(info_path)
return ret
def save(self): def save(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
save_graphs(graph_path, self._trees) save_graphs(graph_path, self._trees)
if self.mode == 'train': vocab_path = os.path.join(self.save_path, 'vocab.pkl')
info_path = os.path.join(self.save_path, 'info.pkl') save_info(vocab_path, {'vocab': self.vocab})
save_info(info_path, {'vocab': self.vocab, 'embed': self.pretrained_emb}) if self.pretrained_emb:
emb_path = os.path.join(self.save_path, 'emb.pkl')
save_info(emb_path, {'embed': self.pretrained_emb})
def load(self): def load(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
vocab_path = os.path.join(self.save_path, 'vocab.pkl')
emb_path = os.path.join(self.save_path, 'emb.pkl')
self._trees = load_graphs(graph_path)[0] self._trees = load_graphs(graph_path)[0]
info_path = os.path.join(self.save_path, 'info.pkl') self._vocab = load_info(vocab_path)['vocab']
if os.path.exists(info_path): if os.path.exists(emb_path):
info = load_info(info_path) self._pretrained_emb = load_info(emb_path)['embed']
self._vocab = info['vocab']
self._pretrained_emb = info['embed']
@property @property
def trees(self): def trees(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment