test_generate_13_grams.py 2.7 KB
Newer Older
researcher2's avatar
researcher2 committed
1
2
3
4
5
import os
from collections import Counter
import shutil
import glob

researcher2's avatar
researcher2 committed
6
from lm_eval.decontamination.janitor import *
researcher2's avatar
researcher2 committed
7
from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets
researcher2's avatar
researcher2 committed
8
from lm_eval.decontamination.archiver import Archive, TextReader
researcher2's avatar
researcher2 committed
9

researcher2's avatar
researcher2 committed
10
11
import logging
logger = logging.getLogger(__name__)
researcher2's avatar
researcher2 committed
12

researcher2's avatar
researcher2 committed
13
def test_generate_13_grams_1(caplog):
    """Check that bucketed 13-gram generation matches direct generation.

    The test builds the expected multiset of 13-grams directly with
    ``word_ngrams`` over the normalized text, then writes the same text into
    an archive under ``pile/``, runs ``do_ngrams_in_buckets`` and re-reads
    every ``*.bkt.txt`` file from ``<working dir>/output``.  The two n-gram
    multisets must be equal.

    Args:
        caplog: pytest log-capture fixture; requested by the signature but
            not used directly in the body.
    """
    data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae. 
    This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese). 
    Some other birds, mostly related to the shelducks, have "goose" as part of their names. 
    More distantly related members of the family Anatidae are swans, most of which are larger 
    than true geese, and ducks, which are smaller. The term "goose" may refer to either a male 
    or female bird, but when paired with "gander", refers specifically to a female one (the latter referring 
    to a male). Young birds before fledging are called goslings. The collective noun for a group of 
    geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when 
    flying close together, they are called a plump."""

    # Duplicate the text so that every n-gram occurs more than once and the
    # comparison below exercises counts, not just membership.
    data = data + data

    # Simple Generation
    print("simple generation")
    n = 13
    janitor = Janitor()
    ngrams = word_ngrams(janitor.normalize_string(data), n)
    comparison = list(ngrams)
    comparison_counter = Counter(comparison)
    print(len(comparison))

    # Generating into buckets
    print("bucket generation")
    test_working_directory = "test_generate_13_grams"
    # Start from an empty working directory; a missing one is fine.
    try:
        shutil.rmtree(test_working_directory)
    except FileNotFoundError:
        pass
    os.makedirs(test_working_directory)

    # Refuse to run when "pile" already exists so that the cleanup below can
    # never delete a directory this test did not create.
    assert not os.path.exists("pile")
    os.makedirs("pile")
    try:
        archive = Archive(os.path.join("pile", "test.jsonl.zst"))
        archive.add_data(data)
        archive.commit()

        bucket_count = 4
        # NOTE(review): presumably do_ngrams_in_buckets picks its input up
        # from the "pile" directory created above -- confirm in the script.
        do_ngrams_in_buckets(n, test_working_directory, bucket_count)
    finally:
        # Always remove the directory we created; otherwise the guard
        # assertion above fails on every subsequent run.  Errors are ignored
        # so a cleanup problem cannot mask the real test failure.
        shutil.rmtree("pile", ignore_errors=True)

    # Rebuild from buckets
    print("rebuild")
    rebuilt_ngrams = []
    bucket_file_paths = glob.glob(
        os.path.join(test_working_directory, "output", "*.bkt.txt")
    )
    for bucket_file_path in bucket_file_paths:
        reader = TextReader(bucket_file_path)

        for line in reader.read():
            # Each line is "<ngram> <document_id>"; the n-gram itself contains
            # spaces, so split once from the right and keep only the n-gram.
            ngram, _document_id = line.rsplit(" ", 1)
            rebuilt_ngrams.append(ngram)

    # Compare
    print("compare")
    result_counter = Counter(rebuilt_ngrams)
    # Cheap check first (number of distinct n-grams), then full equality of
    # n-grams and their counts.
    assert len(result_counter) == len(comparison_counter)
    assert comparison_counter == result_counter