Unverified Commit df8a7be5 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[KG] save embeddings in NumPy (#900)

* fix loading and saving.

* use numpy.
parent 3bc31098
...@@ -86,9 +86,9 @@ class ExternalEmbedding: ...@@ -86,9 +86,9 @@ class ExternalEmbedding:
return nd.concat(*data, dim=0) return nd.concat(*data, dim=0)
def save(self, path, name): def save(self, path, name):
emb_fname = os.path.join(path, name+'.emb') emb_fname = os.path.join(path, name+'.npy')
nd.save(emb_fname, self.emb) np.save(emb_fname, self.emb.asnumpy())
def load(self, path, name): def load(self, path, name):
emb_fname = os.path.join(path, name+'.emb') emb_fname = os.path.join(path, name+'.npy')
self.emb = nd.load(emb_fname)[0] self.emb = nd.array(np.load(emb_fname))
...@@ -96,7 +96,7 @@ class ExternalEmbedding: ...@@ -96,7 +96,7 @@ class ExternalEmbedding:
return th.cat(data, 0) return th.cat(data, 0)
def save(self, path, name): def save(self, path, name):
file_name = os.path.join(path, name) file_name = os.path.join(path, name+'.npy')
np.save(file_name, self.emb.cpu().detach().numpy()) np.save(file_name, self.emb.cpu().detach().numpy())
def load(self, path, name): def load(self, path, name):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment