BoW.py
import json
import logging
import os
from typing import Dict, List, Optional

import torch
from torch import Tensor, nn

from .tokenizer import WhitespaceTokenizer

logger = logging.getLogger(__name__)


class BoW(nn.Module):
    """Implements a Bag-of-Words (BoW) model to derive sentence embeddings.

    A per-word weighting can be supplied to generate tf-idf vectors. The output vector has the size of the vocabulary.
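
    Args:
        vocab: Words spanning the output dimensions; duplicates are removed.
        word_weights: Optional mapping from word to weight, e.g. precomputed idf values.
        unknown_word_weight: Weight assigned to vocab words missing from ``word_weights``.
        cumulative_term_frequency: If True, repeated tokens add their weight each time
            they occur; otherwise the weight is assigned at most once per token.

    Example (a minimal sketch; the vocabulary and sentence are toy data):
        ::

            model = BoW(vocab=["hello", "world"])
            features = model.tokenize(["hello world hello"])
            features["sentence_embedding"].shape  # torch.Size([1, 2])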
    """

    def __init__(
        self,
        vocab: List[str],
        word_weights: Optional[Dict[str, float]] = None,
        unknown_word_weight: float = 1.0,
        cumulative_term_frequency: bool = True,
    ):
        super().__init__()
        if word_weights is None:
            word_weights = {}  # avoid a mutable default argument
        vocab = list(set(vocab))  # Ensure vocab is unique
        self.config_keys = ["vocab", "word_weights", "unknown_word_weight", "cumulative_term_frequency"]
        self.vocab = vocab
        self.word_weights = word_weights
        self.unknown_word_weight = unknown_word_weight
        self.cumulative_term_frequency = cumulative_term_frequency

        # Maps wordIdx -> word weight
        self.weights = []
        num_unknown_words = 0
        for word in vocab:
            weight = unknown_word_weight
            if word in word_weights:
                weight = word_weights[word]
            elif word.lower() in word_weights:
                weight = word_weights[word.lower()]
            else:
                num_unknown_words += 1
            self.weights.append(weight)

        logger.info(
            "{} out of {} words have no weighting value; their weight is set to {}".format(
                num_unknown_words, len(vocab), unknown_word_weight
            )
        )

        self.tokenizer = WhitespaceTokenizer(vocab, stop_words=set(), do_lower_case=False)
        self.sentence_embedding_dimension = len(vocab)

    def forward(self, features: Dict[str, Tensor]) -> Dict[str, Tensor]:
        # Nothing to do here; the embedding is already built in get_sentence_features
        return features

    def tokenize(self, texts: List[str], **kwargs) -> Dict[str, Tensor]:
        tokenized = [self.tokenizer.tokenize(text, **kwargs) for text in texts]
        return self.get_sentence_features(tokenized)

    def get_sentence_embedding_dimension(self):
        return self.sentence_embedding_dimension

    def get_sentence_features(self, tokenized_texts: List[List[int]], pad_seq_length: int = 0):
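        # pad_seq_length is accepted for interface compatibility but not used here:
        # BoW vectors always have the fixed dimension len(vocab).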
        vectors = []

        for tokens in tokenized_texts:
            vector = torch.zeros(self.get_sentence_embedding_dimension(), dtype=torch.float32)
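            # Accumulate each token's weight at its vocabulary index
            # (or assign it just once when cumulative_term_frequency is False).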
            for token in tokens:
                if self.cumulative_term_frequency:
                    vector[token] += self.weights[token]
                else:
                    vector[token] = self.weights[token]
            vectors.append(vector)

        return {"sentence_embedding": torch.stack(vectors)}

    def get_config_dict(self):
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path):
        with open(os.path.join(output_path, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)

        return BoW(**config)
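

if __name__ == "__main__":
    # Minimal round-trip sketch with toy data; the sentences and temp directory are
    # illustrative only. Note: the relative tokenizer import above means this module
    # must be run as part of its package (e.g. python -m <package>.BoW).
    import tempfile

    model = BoW(vocab=["the", "quick", "brown", "fox"])
    features = model.tokenize(["the quick brown fox", "the fox"])
    print(features["sentence_embedding"].shape)  # torch.Size([2, 4])

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save(tmp_dir)
        restored = BoW.load(tmp_dir)
        assert restored.get_sentence_embedding_dimension() == 4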