Unverified Commit b3538802 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Bugfix] fix vocab not saved bug in SSTDataset (#2036)

* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* Update tree.py
parent 5cb57593
......@@ -195,27 +195,27 @@ class SSTDataset(DGLBuiltinDataset):
def has_cache(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
ret = os.path.exists(graph_path)
if self.mode == 'train':
info_path = os.path.join(self.save_path, 'graph_info.pkl')
ret = ret and os.path.exists(info_path)
return ret
vocab_path = os.path.join(self.save_path, 'vocab.pkl')
return os.path.exists(graph_path) and os.path.exists(vocab_path)
def save(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
save_graphs(graph_path, self._trees)
if self.mode == 'train':
info_path = os.path.join(self.save_path, 'info.pkl')
save_info(info_path, {'vocab': self.vocab, 'embed': self.pretrained_emb})
vocab_path = os.path.join(self.save_path, 'vocab.pkl')
save_info(vocab_path, {'vocab': self.vocab})
if self.pretrained_emb:
emb_path = os.path.join(self.save_path, 'emb.pkl')
save_info(emb_path, {'embed': self.pretrained_emb})
def load(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
vocab_path = os.path.join(self.save_path, 'vocab.pkl')
emb_path = os.path.join(self.save_path, 'emb.pkl')
self._trees = load_graphs(graph_path)[0]
info_path = os.path.join(self.save_path, 'info.pkl')
if os.path.exists(info_path):
info = load_info(info_path)
self._vocab = info['vocab']
self._pretrained_emb = info['embed']
self._vocab = load_info(vocab_path)['vocab']
if os.path.exists(emb_path):
self._pretrained_emb = load_info(emb_path)['embed']
@property
def trees(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment