"""Tests for bucketed 13-gram generation (scripts.clean_training_data)."""
import glob
import logging
import os
import shutil
from collections import Counter

from lm_eval.decontamination.archiver import Archive, TextReader
from lm_eval.decontamination.janitor import Janitor, word_ngrams

from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets


# Module-level logger; the test below reports its progress through it.
logger = logging.getLogger(__name__)


def test_generate_13_grams_1(caplog):
    """Bucketed 13-gram generation must reproduce direct in-memory generation.

    The sample text is turned into n-grams in two ways:

    1. Directly, via ``Janitor.normalize_string`` followed by ``word_ngrams``.
    2. By writing the text to ``../pile/test.jsonl.zst`` with ``Archive`` and
       running ``do_ngrams_in_buckets``. The n-grams are then read back from
       the ``*.bkt.txt`` files under ``<working dir>/output``.

    The two multisets of n-grams (n-gram -> occurrence count) must be equal.

    Side effects: recreates ``test_generate_13_grams`` in the current working
    directory. Temporarily creates ``../pile`` and removes it again, even when
    bucketing fails. NOTE(review): the test assumes ``do_ngrams_in_buckets``
    reads its input archives from ``../pile``; confirm against the script.

    Args:
        caplog: pytest fixture capturing the log records emitted below.
    """
    data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
    This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
    Some other birds, mostly related to the shelducks, have "goose" as part of their names.
    More distantly related members of the family Anatidae are swans, most of which are larger
    than true geese, and ducks, which are smaller. The term "goose" may refer to either a male
    or female bird, but when paired with "gander", refers specifically to a female one (the latter referring
    to a male). Young birds before fledging are called goslings. The collective noun for a group of
    geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when
    flying close together, they are called a plump."""

    # Concatenate the text with itself so n-grams occur more than once and the
    # per-n-gram count comparison below is actually exercised.
    data = data + data

    n = 13

    # Reference result: generate the n-grams directly in memory.
    logger.info("simple generation")
    janitor = Janitor()
    comparison = list(word_ngrams(janitor.normalize_string(data), n))
    comparison_counter = Counter(comparison)
    logger.info("%d reference n-grams", len(comparison))

    # Result under test: archive the text and run the bucketing pipeline.
    logger.info("bucket generation")
    test_working_directory = "test_generate_13_grams"
    # Start from an empty working directory so output left by a previous run
    # cannot leak into the comparison.
    try:
        shutil.rmtree(test_working_directory)
    except FileNotFoundError:
        pass
    os.makedirs(test_working_directory)

    # Refuse to run if ../pile already exists. The test must never write into,
    # or later delete, a directory it did not create itself.
    pile_directory = "../pile"
    assert not os.path.exists(pile_directory)
    os.makedirs(pile_directory)
    try:
        archive = Archive(os.path.join(pile_directory, "test.jsonl.zst"))
        archive.add_data(data)
        archive.commit()

        bucket_count = 4
        do_ngrams_in_buckets(n, test_working_directory, bucket_count)
    finally:
        # Remove the scratch pile directory even on failure. Leaving it behind
        # makes the guard assert above fail on every subsequent run.
        shutil.rmtree(pile_directory, ignore_errors=True)

    # Rebuild the n-gram multiset from the bucket files.
    logger.info("rebuild")
    bucket_file_paths = glob.glob(
        os.path.join(test_working_directory, "output", "*.bkt.txt")
    )
    assert bucket_file_paths, "do_ngrams_in_buckets produced no *.bkt.txt files"

    rebuilt_ngrams = []
    for bucket_file_path in bucket_file_paths:
        reader = TextReader(bucket_file_path)
        for line in reader.read():
            # Each line is "<ngram> <document_id>". The n-gram may itself
            # contain spaces, so split only on the last one.
            ngram, _document_id = line.rsplit(" ", 1)
            rebuilt_ngrams.append(ngram)

    # Compare: first the number of distinct n-grams (a more readable failure),
    # then the full multiset including multiplicities.
    logger.info("compare")
    result_counter = Counter(rebuilt_ngrams)
    assert len(result_counter) == len(comparison_counter)
    assert comparison_counter == result_counter