import os
from collections import Counter
import shutil
import glob

from lm_eval.decontamination.janitor import Janitor, word_ngrams
from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets
from lm_eval.decontamination.archiver import Archive, TextReader

import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def test_generate_13_grams_1(caplog):
    """Check that bucketed 13-gram generation matches in-memory generation.

    The test builds the reference multiset of 13-grams directly with
    ``word_ngrams`` over the normalized sample text, then writes the same
    text into an archive under ``pile/``, runs ``do_ngrams_in_buckets`` and
    rebuilds the n-grams from the ``*.bkt.txt`` files found in the working
    directory's ``output`` folder. Both multisets must be equal.

    Every directory the test creates is removed afterwards, so the test can
    be re-run from the same current working directory.

    Args:
        caplog: pytest log-capture fixture; requested by the signature but
            not used by the body.
    """
    data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
    This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
    Some other birds, mostly related to the shelducks, have "goose" as part of their names.
    More distantly related members of the family Anatidae are swans, most of which are larger
    than true geese, and ducks, which are smaller. The term "goose" may refer to either a male
    or female bird, but when paired with "gander", refers specifically to a female one (the latter referring
    to a male). Young birds before fledging are called goslings. The collective noun for a group of
    geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when
    flying close together, they are called a plump."""

    # Duplicate the text so that every n-gram of the passage occurs more than once.
    data = data + data

    # Simple Generation
    print("simple generation")
    n = 13
    janitor = Janitor()
    ngrams = word_ngrams(janitor.normalize_string(data), n)
    comparison = list(ngrams)
    comparison_counter = Counter(comparison)
    print(len(comparison))

    # Generating into buckets
    print("bucket generation")
    test_working_directory = "test_generate_13_grams"
    pile_directory = "pile"

    # Refuse to run when a "pile" directory already exists. Checking before
    # anything is created guarantees the cleanup below only ever deletes
    # directories made by this test, never pre-existing user data.
    assert not os.path.exists(pile_directory)

    # Start from an empty working directory even if an earlier run left one.
    try:
        shutil.rmtree(test_working_directory)
    except FileNotFoundError:
        pass

    try:
        os.makedirs(test_working_directory)
        os.makedirs(pile_directory)

        # NOTE(review): do_ngrams_in_buckets is only given the working
        # directory, so it presumably discovers its input under ./pile
        # relative to the current working directory -- confirm.
        archive = Archive(os.path.join(pile_directory, "test.jsonl.zst"))
        archive.add_data(data)
        archive.commit()

        bucket_count = 4
        do_ngrams_in_buckets(n, test_working_directory, bucket_count)

        # Rebuild from buckets
        print("rebuild")
        rebuilt_ngrams = []
        bucket_file_paths = glob.glob(
            os.path.join(test_working_directory, "output", "*.bkt.txt")
        )
        for bucket_file_path in bucket_file_paths:
            reader = TextReader(bucket_file_path)

            for line in reader.read():
                # Each line is "<ngram> <document_id>"; the n-gram itself
                # contains spaces, so split on the last space only.
                ngram, _document_id = line.rsplit(" ", 1)
                rebuilt_ngrams.append(ngram)

        # Compare
        print("compare")
        result_counter = Counter(rebuilt_ngrams)
        assert len(result_counter) == len(comparison_counter)
        assert comparison_counter == result_counter
    finally:
        # Remove everything this test created, on success and on failure;
        # a leftover "pile" would trip the existence assertion on the next run.
        shutil.rmtree(pile_directory, ignore_errors=True)
        shutil.rmtree(test_working_directory, ignore_errors=True)