# DenoisingAutoEncoderDataset.py
from typing import List
import numpy as np
from torch.utils.data import Dataset
from transformers.utils.import_utils import NLTK_IMPORT_ERROR, is_nltk_available

from sentence_transformers.readers.InputExample import InputExample


class DenoisingAutoEncoderDataset(Dataset):
    """
    The DenoisingAutoEncoderDataset returns InputExamples in the format: texts=[noise_fn(sentence), sentence]
    It is used in combination with the DenoisingAutoEncoderLoss: Here, a decoder tries to re-construct the
    sentence without noise.

    Args:
        sentences: A list of sentences
        noise_fn: A noise function: Given a string, it returns a string
            with noise, e.g. deleted words. Defaults to
            :meth:`DenoisingAutoEncoderDataset.delete`.

    Raises:
        ImportError: If NLTK is not available. The check runs on every
            construction, regardless of which ``noise_fn`` is passed.
    """

    # The default is wrapped in a lambda because the class name is not yet bound
    # while the class body is being executed; the lookup is deferred to call time.
    def __init__(self, sentences: List[str], noise_fn=lambda s: DenoisingAutoEncoderDataset.delete(s)):
        if not is_nltk_available():
            raise ImportError(NLTK_IMPORT_ERROR.format(self.__class__.__name__))

        self.sentences = sentences
        self.noise_fn = noise_fn

    def __getitem__(self, item):
        """Return an InputExample with texts=[noisy sentence, original sentence] for index ``item``."""
        sent = self.sentences[item]
        return InputExample(texts=[self.noise_fn(sent), sent])

    def __len__(self):
        """Return the number of sentences in the dataset."""
        return len(self.sentences)

    # Deletion noise.
    @staticmethod
    def delete(text, del_ratio=0.6):
        """
        Randomly delete words from ``text``.

        The text is tokenized with NLTK's ``word_tokenize``; each token is dropped
        when its uniform random draw from [0, 1) is not greater than ``del_ratio``.
        The surviving tokens are joined back with ``TreebankWordDetokenizer``.

        Args:
            text: The input string.
            del_ratio: Threshold compared against each token's random draw;
                higher values delete more tokens.

        Returns:
            The noised string. If tokenization yields no tokens, ``text`` is
            returned unchanged. At least one token is always kept.
        """
        # Imported lazily so that importing this module does not require NLTK.
        from nltk import word_tokenize
        from nltk.tokenize.treebank import TreebankWordDetokenizer

        words = word_tokenize(text)
        n = len(words)
        if n == 0:
            return text

        keep_or_not = np.random.rand(n) > del_ratio
        if not keep_or_not.any():
            keep_or_not[np.random.choice(n)] = True  # guarantee that at least one word remains
        # Filter the Python token list directly instead of round-tripping it
        # through a fixed-width NumPy string array.
        kept_words = [word for word, keep in zip(words, keep_or_not) if keep]
        return TreebankWordDetokenizer().detokenize(kept_words)