"docs/source/vscode:/vscode.git/clone" did not exist on "06e9ebebd51c3db779dedec5556251c8ecc3a00a"
Unverified Commit df8a7be5 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[KG] save embeddings in NumPy (#900)

* fix loading and saving.

* use numpy.
parent 3bc31098
......@@ -86,9 +86,9 @@ class ExternalEmbedding:
return nd.concat(*data, dim=0)
def save(self, path, name):
emb_fname = os.path.join(path, name+'.emb')
nd.save(emb_fname, self.emb)
emb_fname = os.path.join(path, name+'.npy')
np.save(emb_fname, self.emb.asnumpy())
def load(self, path, name):
emb_fname = os.path.join(path, name+'.emb')
self.emb = nd.load(emb_fname)[0]
emb_fname = os.path.join(path, name+'.npy')
self.emb = nd.array(np.load(emb_fname))
......@@ -96,7 +96,7 @@ class ExternalEmbedding:
return th.cat(data, 0)
def save(self, path, name):
file_name = os.path.join(path, name)
file_name = os.path.join(path, name+'.npy')
np.save(file_name, self.emb.cpu().detach().numpy())
def load(self, path, name):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment