# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import mmap
import os

import numpy as np
from six import string_types


class Vecs(object):
  """Unit-normalized word embeddings loaded from vocabulary and vector files.

  Attributes:
    vocab: List of tokens, in the order they appear in the vocabulary file.
    word_to_idx: Dict mapping each token to its row index in `vecs`. If a
      token occurs more than once in the vocabulary, the last index wins.
    vecs: `np.matrix` of shape (len(vocab), dim); every row is scaled to unit
      L2 norm, so the product of two rows is their cosine similarity.
  """

  def __init__(self, vocab_filename, rows_filename, cols_filename=None):
    """Initializes the vectors from a text vocabulary and binary data.

    Args:
      vocab_filename: Text file with one token per line. Only the first
        whitespace-separated field of each line is used.
      rows_filename: Binary file holding len(vocab) * dim float32 values in
        row-major order; `dim` is derived from the file size.
      cols_filename: Optional binary file with the same size and layout as
        `rows_filename`. If given, it is added element-wise to the row
        vectors before normalization.

    Raises:
      IOError: If the vocabulary is empty, if the row file is empty or its
        size is not a multiple of 4 * len(vocab), or if the column file has a
        different size than the row file.
    """
    with open(vocab_filename, 'r') as lines:
      self.vocab = [line.split()[0] for line in lines]
    self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}

    n = len(self.vocab)
    if n == 0:
      # Without this guard the size check below divides by zero.
      raise IOError('vocabulary file %s is empty' % vocab_filename)

    # Make sure that the file size seems reasonable: n rows of float32 values.
    size = self._file_size(rows_filename)
    if size == 0 or size % (4 * n) != 0:
      raise IOError(
          'unexpected file size for binary vector file %s' % rows_filename)

    # Integer division: `dim` is used as an array dimension below.
    dim = size // (4 * n)
    rows = self._read_matrix(rows_filename, n, dim)

    # If column vectors were specified, then open them and add them to the
    # row vectors.
    if cols_filename:
      if self._file_size(cols_filename) != size:
        raise IOError('row and column vector files have different sizes')
      rows += self._read_matrix(cols_filename, n, dim)

    # Normalize so that dot products are just cosine similarity.
    self.vecs = rows / np.linalg.norm(rows, axis=1).reshape(n, 1)

  @staticmethod
  def _file_size(filename):
    """Returns the size of `filename` in bytes."""
    with open(filename, 'rb') as fh:
      fh.seek(0, os.SEEK_END)
      return fh.tell()

  @staticmethod
  def _read_matrix(filename, n, dim):
    """Reads an n x dim float32 matrix from a binary file via a memory map.

    The returned matrix is an in-memory copy (np.matrix copies its input by
    default), so it stays valid and writable after the map is closed.
    """
    with open(filename, 'rb') as fh:
      # access=ACCESS_READ is the portable spelling of a read-only mapping;
      # closing() guarantees the map is released even if reshaping fails.
      with contextlib.closing(
          mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)) as mapped:
        return np.matrix(
            np.frombuffer(mapped, dtype=np.float32).reshape(n, dim))

  def similarity(self, word1, word2):
    """Computes the similarity of two tokens.

    Args:
      word1: First token.
      word2: Second token.

    Returns:
      The dot product of the two unit-normalized embeddings as a float, or
      None if either token is not in the vocabulary.
    """
    idx1 = self.word_to_idx.get(word1)
    idx2 = self.word_to_idx.get(word2)
    # Compare against None explicitly: index 0 is a valid (but falsy) row.
    if idx1 is None or idx2 is None:
      return None

    # The product of a 1 x dim row and a dim x 1 column is a 1 x 1 matrix.
    product = self.vecs[idx1] * self.vecs[idx2].transpose()
    return float(product[0, 0])

  def neighbors(self, query):
    """Returns the nearest neighbors to the query (a word or vector).

    Args:
      query: Either a token, or an embedding vector with `dim` elements (for
        example the result of `lookup`, a 1-D array, or a list of floats).

    Returns:
      A list of (token, score) tuples covering the whole vocabulary, sorted by
      descending score, where score is the dot product between the token's
      unit-normalized embedding and the query vector. Returns None if the
      query is a token that is not in the vocabulary.
    """
    if isinstance(query, string_types):
      idx = self.word_to_idx.get(query)
      if idx is None:
        return None

      query = self.vecs[idx]

    # Flatten the query into a dim x 1 column so that 1 x dim matrix rows,
    # 1-D arrays and plain sequences are all accepted.
    column = np.asarray(query).reshape(-1, 1)
    scores = self.vecs * column

    return sorted(
        zip(self.vocab, scores.flat),
        key=lambda kv: kv[1], reverse=True)

  def lookup(self, word):
    """Returns the embedding for a token, or None if no embedding exists.

    The embedding is returned as a 1 x dim row of the normalized matrix.
    """
    idx = self.word_to_idx.get(word)
    return None if idx is None else self.vecs[idx]