test_generate_13_grams.py 2.59 KB
Newer Older
researcher2's avatar
researcher2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
from collections import Counter
import shutil
import glob

from scripts.clean_training_data.janitor import *
from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets
from scripts.clean_training_data.archiver import Archive, TextReader


def test_generate_13_grams_1():
    data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae. 
    This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese). 
    Some other birds, mostly related to the shelducks, have "goose" as part of their names. 
    More distantly related members of the family Anatidae are swans, most of which are larger 
    than true geese, and ducks, which are smaller. The term "goose" may refer to either a male 
    or female bird, but when paired with "gander", refers specifically to a female one (the latter referring 
    to a male). Young birds before fledging are called goslings. The collective noun for a group of 
    geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when 
    flying close together, they are called a plump."""

    data = data + data

    # Simple Generation
    n = 13
    janitor = Janitor()    
    ngrams = word_ngrams(janitor.normalize_string(data), n)
    comparison = list(ngrams)
    comparison_counter = Counter(comparison)
    print(len(comparison))
    # print(comparison)

    # Generating into buckets
    test_working_directory = "test_generate_13_grams"
    output_directory = os.path.join(test_working_directory, "output")        
    try:
        shutil.rmtree(output_directory)
    except FileNotFoundError:
        pass
    os.makedirs(test_working_directory, exist_ok=True)
    archive = Archive(os.path.join(test_working_directory, "test.jsonl.zst"))
    archive.add_data(data)
    archive.commit()
    bucket_count = 4
    do_ngrams_in_buckets(n, test_working_directory, bucket_count)

    # Rebuild from buckets
    rebuilt_ngrams = []

    bucket_file_paths = glob.glob(os.path.join(test_working_directory, "output", f"*.bkt.txt")) 
    for bucket_file_path in bucket_file_paths:
        reader = TextReader(bucket_file_path)
        
        for line in reader.read():
            [ngram, document_id] = line.rsplit(" ", 1)
            rebuilt_ngrams.append(ngram)

    # Compare
    result_counter = Counter(rebuilt_ngrams)
    # print(len(result_counter))
    # print(len(comparison_counter))
    assert(len(result_counter) == len(comparison_counter))
    # print(result_counter)
    # print(comparison_counter)    
    assert(comparison_counter == result_counter)