gaoqiong / lm-evaluation-harness

Commit f495bfb4 authored Mar 03, 2022 by researcher2

Refactor for PR

parent 0f283a9c

Showing 11 changed files with 407 additions and 68 deletions (+407 -68)
lm_eval/decontamination/__init__.py                  +0    -0
lm_eval/decontamination/archiver.py                  +8    -4
lm_eval/decontamination/decontaminate.py             +153  -0
lm_eval/decontamination/janitor.py                   +0    -1
lm_eval/evaluator.py                                 +2    -1
main.py                                              +2    -20
pile_statistics.json                                 +37   -0
scripts/clean_training_data/generate_13_grams.py     +109  -33
scripts/clean_training_data/investigate_pile.py      +79   -0
tests/test_generate_13_grams.py                      +16   -8
tests/test_janitor.py                                +1    -1
lm_eval/decontamination/__init__.py  0 → 100644 (new, empty file)
scripts/clean_training_data/archiver.py → lm_eval/decontamination/archiver.py
...
@@ -6,6 +6,7 @@ import io
 import datetime
 import mmap
 import tqdm
+from pathlib import Path

 def json_serial(obj):
     """JSON serializer for objects not serializable by default json code"""
...
@@ -61,16 +62,19 @@ class Reader:
         else:
             yield text

 # Simple text reader and writer with same interface as above
 class TextArchive:
-    def __init__(self, file_path, mode="ab"):
+    def __init__(self, file_path, mode="rb+"):
         self.file_path = file_path
         dir_name = os.path.dirname(file_path)
         if dir_name:
             os.makedirs(dir_name, exist_ok=True)

-        self.fh = open(self.file_path, mode)
+        if not os.path.exists(file_path):
+            Path(file_path).touch()
+
+        self.fh = open(self.file_path, mode)

-    def add_data(self, data, meta={}):
+    def add_data(self, data):
         self.fh.write(data.encode('UTF-8') + b'\n')

     def commit(self):
...
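
Note on the TextArchive change above: switching the default mode from "ab" to "rb+" (with an explicit Path.touch() for missing files) is what lets callers seek, tell and truncate the handle when resuming from a checkpoint. A minimal usage sketch with a hypothetical scratch path, assuming the new lm_eval.decontamination package layout:

    # Illustrative sketch only; the file path is a placeholder.
    from lm_eval.decontamination.archiver import TextArchive

    archive = TextArchive("scratch/ngrams_0.bkt.txt")  # touched if missing, opened "rb+"
    archive.fh.seek(0, 2)               # "rb+" starts at offset 0, so seek to the end before appending
    archive.add_data("some ngram 42")   # writes UTF-8 text plus a trailing newline
    checkpoint = archive.fh.tell()      # a caller can record this offset...
    archive.fh.seek(checkpoint)
    archive.fh.truncate()               # ...and truncate back to it when resuming after a crash
    archive.commit()                    # flush and close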
lm_eval/decontamination.py → lm_eval/decontamination/decontaminate.py
-import mmap
-import string
 import time
 import random
 import pickle
+import json
 import glob
 import os
 import collections
-import re
-import tqdm
+
+from .janitor import Janitor, word_ngrams
+from .archiver import ZStdTextReader

-try:
-    import janitor_util
-    JANITOR_CPP = True
-except Exception as e:
-    # print("WARNING: C++ module could not be loaded. Janitor running in python mode")
-    JANITOR_CPP = False

 # Was used for testing the evaluator decoupled from the full logic below
 def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
...
@@ -25,11 +17,12 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
 # Returns a dictionary containing all overlapping documents in each
-# task based on any 13gram being found in the training set.
-# ngrams_path is the parent directory containing the "ngrams_{x}.bkt.txt.sorted.zst"
-# files built by the other scripts "generate_13_grams.py" and "sort_13_gram_buckets.py.
-# ngrams_n_size is expected to be 13 but we made it a parameter for generality.
-# The task set is only included for caching purposes.
+# task. In the standard use case, an overlap occurs when any of the 13-grams
+# found in the task document exist in the training set documents.
+#
+# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
+# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
+# files. These should exist in the "ngrams_path" provided to this function.

 # Algorithm:
 # 1. Build lookups for each dataset {ngram: list(document_ids)}
...
@@ -39,12 +32,12 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
 # 4. Strip the task_set from the dictionary keys and return
 #
 # We cache the task+set lookups as well as the overlaps.
-#
-# Currently calculating some per file ngram stats for interest, might remove before merging into main
 def get_train_overlap(docs_by_task_set, ngrams_path, limit):
     # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
-    # TODO: infer ngrams_n_size from ngrams_path
+    info_dict_path = os.path.join(ngrams_path, "info.json")
+    info_dict = json.load(open(info_dict_path, "r"))
+    ngrams_n_size = info_dict["ngram_size"]

     janitor = Janitor()
...
@@ -62,16 +55,17 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
     for (task_name, task_set), docs in docs_by_task_set.items():
         if not os.path.exists(f"data/{task_name}"):
             os.mkdir(f"data/{task_name}")

-        # Check if we've decontaminated this set before
+        # Check if we've decontaminated this combination before
         overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
         if os.path.exists(overlaps_dump_path):
             duplicates[(task_name, task_set)] = pickle.load(open(overlaps_dump_path, "rb"))
             sets_to_decontaminate -= 1
             continue
         else:
-            duplicates[(task_name, task_set)] = set()
+            # No defaultdict, we want to dump empty sets too later
+            duplicates[(task_name, task_set)] = set()

-        # Build/load the task lookup {ngram: documents}.
+        # Build/load the task lookup {ngram: set(documents)}.
         task_set_lookup_path = f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
         if os.path.exists(task_set_lookup_path):
             print(f"{task_set_lookup_path} available, loading...")
...
@@ -104,7 +98,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
     elapsed = time.perf_counter() - start
     print(f"Merging lookups took {elapsed:0.5f} seconds.")

-    print(f"13 grams files found in {ngrams_path}:")
+    print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
     files = glob.glob(os.path.join(ngrams_path, f"*.sorted.zst"))
     print(files)
...
@@ -157,281 +151,3 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
     # Strip task set and return
     return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
-
-
-# Implementation from nltk source
-# https://www.nltk.org/_modules/nltk/util.html
-def form_ngrams(sequence, n):
-    history = []
-    while n > 1:
-        # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
-        try:
-            next_item = next(sequence)
-        except StopIteration:
-            # no more data, terminate the generator
-            return
-        history.append(next_item)
-        n -= 1
-
-    for item in sequence:
-        history.append(item)
-        yield tuple(history)
-        del history[0]
-
-
-def word_ngrams(s, n):
-    """Splits a string into ngram words"""
-    tokens = s.split()  # not a generator :(
-    ngram_seqs = form_ngrams(iter(tokens), n)
-    return (" ".join(ngram) for ngram in ngram_seqs)
-
-
-# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
-def split_indices(s):
-    """Splits a string on whitespaces and records the indices of each in the original string.
-    @:return generator((word, (start_idx, end_idx)), ...)
-    """
-    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))
-
-
-def word_ngrams_indices(s, n):
-    """Splits a string into pairs of (ngram words, their start/end indices)"""
-    tokens_with_indices = split_indices(s)
-
-    # Generator of ngrams of (word, idx_pairs)
-    # (
-    #   [(word, (start,end)), (word, (start, end))...],
-    #   [(word, (start, end)), ...],
-    #   ...
-    # )
-    ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)
-
-    # Generator of pairs of word and index ngrams
-    # (
-    #   ([word, word, ...], [(start,end), (start,end), ...]),
-    #   ...
-    # )
-    ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)
-
-    # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
-    return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
-            for ngram_seq, indices in ngram_indices_pairs)
-
-
-class Janitor:
-
-    # FIXME delete_chars: Should anything else go here? Special chars?
-    def __init__(self, ngram_n=13, window_to_remove=200, too_dirty_cutoff=10,
-                 minimum_slice_length=200, delete_chars=string.punctuation):
-        self.ngram_n = ngram_n
-        self.window_to_remove = window_to_remove
-        self.too_dirty_cutoff = too_dirty_cutoff
-        self.minimum_slice_length = minimum_slice_length
-        self.delete_chars = delete_chars
-
-        self.dirt_ngrams = set()
-
-        # If in python, we'll translate uppercase to lowercase and delete naughty characters.
-        # This is fast by python standards
-        # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
-        self.translation_table = str.maketrans(
-            string.ascii_lowercase + string.ascii_uppercase,  # These characters
-            string.ascii_lowercase * 2,                       # Become these characters
-            self.delete_chars                                 # These are deleted
-        )
-
-    ##############
-    # I/O for saving contamination ngrams
-    ##############
-
-    def save_contamination_ngrams(self, filename):
-        with open(filename, 'wb') as fp:
-            pickle.dump(filename, fp)
-
-    def load_contamination_ngrams(self, filename):
-        with open(filename, 'rb') as fp:
-            self.dirt_ngrams = pickle.load(fp)
-
-    ##############
-    # Call these :)
-    ##############
-
-    def register_contaminant(self, dirt_string):
-        """Register a string as contamination to be removed, e.g. a test set
-        This breaks the dirt_string into ngrams to store for future cleaning"""
-        if JANITOR_CPP:
-            return self.register_contaminant_cpp(dirt_string)
-        else:
-            print("WARNING: Janitor running in python mode")
-            return self.register_contaminant_python(dirt_string)
-
-    def clean(self, dirty_string):
-        """Clean a string (e.g. a training set) by removing all ngrams previously
-        reigstered as contaminants. Returns a list of clean chunks, or empty if
-        the string was too dirty"""
-        if JANITOR_CPP:
-            return self.clean_cpp(dirty_string)
-        else:
-            print("WARNING: Janitor running in python mode")
-            return self.clean_python(dirty_string)
-
-    def _split_chunks(self, dirty_string, dirty_parts):
-        clean_chunks = []
-        splice_idx = 0
-        end = -1
-        for i, (ngram, start, end) in enumerate(dirty_parts):
-            if i >= self.too_dirty_cutoff:
-                return []
-            start = max(0, start - self.window_to_remove)
-            end = min(len(dirty_string), end + self.window_to_remove)
-
-            if start - splice_idx > self.minimum_slice_length:
-                clean_chunks.append(dirty_string[splice_idx:start])
-            splice_idx = end
-
-        if end < len(dirty_string) - self.minimum_slice_length:
-            clean_chunks.append(dirty_string[end + 1:])
-
-        return clean_chunks
-
-    ##############
-    # Fast C++
-    ##############
-
-    def register_contaminant_cpp(self, dirt_string):
-        self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))
-
-    def clean_cpp(self, dirty_string):
-        contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n)
-        return self._split_chunks(dirty_string, contamination_indices)
-
-    ##############
-    # Slow python
-    ##############
-
-    def normalize_string(self, s):
-        return s.translate(self.translation_table)
-
-    def register_contaminant_python(self, dirt_string):
-        self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n))
-
-    def clean_python(self, dirty_string):
-        contamination_indices = (
-            (None, *idx_pair)
-            for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
-            if self.normalize_string(dirty_ngram) in self.dirt_ngrams
-        )
-        return self._split_chunks(dirty_string, contamination_indices)
-
-
-# Implementation from nltk source
-# https://www.nltk.org/_modules/nltk/util.html
-def form_ngrams(sequence, n):
-    history = []
-    while n > 1:
-        # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
-        try:
-            next_item = next(sequence)
-        except StopIteration:
-            # no more data, terminate the generator
-            return
-        history.append(next_item)
-        n -= 1
-
-    for item in sequence:
-        history.append(item)
-        yield tuple(history)
-        del history[0]
-
-
-def word_ngrams(s, n):
-    """Splits a string into ngram words"""
-    tokens = s.split()  # not a generator :(
-    ngram_seqs = form_ngrams(iter(tokens), n)
-    return (" ".join(ngram) for ngram in ngram_seqs)
-
-
-# Simple text reader and writer with same interface as above
-class TextArchive:
-    def __init__(self, file_path, mode="ab"):
-        self.file_path = file_path
-        dir_name = os.path.dirname(file_path)
-        if dir_name:
-            os.makedirs(dir_name, exist_ok=True)
-
-        self.fh = open(self.file_path, mode)
-
-    def add_data(self, data, meta={}):
-        self.fh.write(data.encode('UTF-8') + b'\n')
-
-    def commit(self):
-        self.fh.flush()
-        self.fh.close()
-
-
-class TextReader:
-    def __init__(self, file_path):
-        self.file_path = file_path
-
-    # Optimized mmap read with infrequent tqdm updates to maintain speed
-    # Tested up to 250MB/s.
-    def read_tqdm(self, update_frequency=10000):
-        current_file_position = 0
-        line_counter = 0
-        with open(self.file_path, 'r') as fh, \
-            tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True,
-                      unit="byte", unit_scale=1) as progress:
-            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
-                for line in iter(mmap_obj.readline, b""):
-                    line = line.decode("utf-8")
-                    line_counter += 1
-                    if line_counter == update_frequency:
-                        new_file_pos = mmap_obj.tell()
-                        bytes_read = new_file_pos - current_file_position
-                        current_file_position = new_file_pos
-                        progress.update(bytes_read)
-                        line_counter = 0
-                    yield line[:-1]
-
-    def read_and_tell(self):
-        current_file_position = 0
-        with open(self.file_path, 'r', encoding="utf8") as fh:
-            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
-                for line in iter(mmap_obj.readline, b""):
-                    line = line.decode("utf-8")
-                    new_file_pos = mmap_obj.tell()
-                    raw_bytes_read = new_file_pos - current_file_position
-                    current_file_position = new_file_pos
-                    yield line[:-1], raw_bytes_read
-
-    def read(self):
-        with open(self.file_path, 'r', encoding="utf8") as fh:
-            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
-                for line in iter(mmap_obj.readline, b""):
-                    line = line.decode("utf-8")
-                    yield line[:-1]
-
-    def read_slow(self):
-        with open(self.file_path, 'r', encoding="utf8") as fh:
-            while True:
-                line = fh.readline()
-                if line == -1 or line == "":
-                    break
-                else:
-                    yield line[:-1]
-
-
-# Optimized for speed. Decompresses the archive in shell before
-# using the mmap'd TextReader.
-class ZStdTextReader:
-    def __init__(self, file):
-        self.file = file
-
-    def read_tqdm(self):
-        decompressed_file = self.file[:-4]
-        print("Decompressing file, please wait...")
-        os.system(f"zstd -d {self.file}")  # linux decompress is faster
-        reader = TextReader(decompressed_file)
-        yield from reader.read_tqdm()
-        os.remove(decompressed_file)
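
For reviewers, a sketch of the call shape get_train_overlap() now expects: docs_by_task_set is keyed by (task_name, task_set) pairs, and ngrams_path points at a directory containing info.json plus the "ngrams_{x}.bkt.txt.sorted.zst" files produced by scripts/clean_training_data. Task names, documents and paths below are placeholders, and actually running this requires those generated files to exist on disk:

    # Illustrative sketch only; names and paths are placeholders.
    from lm_eval.decontamination.decontaminate import get_train_overlap

    docs_by_task_set = {
        ("lambada", "test"): ["first test document ...", "second test document ..."],
    }

    # "path/to/13grams" must hold info.json (with "ngram_size") and the sorted n-gram buckets.
    overlaps = get_train_overlap(docs_by_task_set, "path/to/13grams", limit=None)
    # -> {"lambada": set_of_overlapping_document_ids}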
scripts/clean_training_data/janitor.py → lm_eval/decontamination/janitor.py
...
@@ -4,7 +4,6 @@ import timeit
 import pickle
 import traceback
 from pprint import pprint
-from lm_eval.decontamination import word_ngrams

 # This is a cpp module. Compile janitor_util.cpp with:
 # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
...
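
The janitor module keeps its interface after the move; a small sketch of the register-then-clean flow (the strings are placeholders, and without the compiled janitor_util extension it falls back to the slower pure-Python path with a printed warning):

    # Illustrative sketch only; the strings are placeholders.
    from lm_eval.decontamination.janitor import Janitor

    janitor = Janitor(ngram_n=13)
    janitor.register_contaminant("a benchmark prompt that must not appear in the training data " * 3)
    chunks = janitor.clean("some long training document ... " * 100)
    # `chunks` is a list of clean text slices, or [] if the document was too contaminated.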
lm_eval/evaluator.py
...
@@ -8,6 +8,7 @@ import lm_eval.base
 import lm_eval.decontamination
 import numpy as np
 from lm_eval.utils import positional_deprecated
+from lm_eval.decontamination.decontaminate import get_train_overlap


 @positional_deprecated
 def simple_evaluate(model, model_args=None, tasks=[],
...
@@ -189,7 +190,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     # Compare all tasks/sets at once to ensure a single training set scan
     if decontaminate:
         print("Finding train/test overlap, please wait...")
-        overlaps = lm_eval.decontamination.get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)
+        overlaps = get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)

     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)
...
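
With get_train_overlap imported directly, decontamination is driven by the evaluator's keyword arguments. A hedged sketch of a caller requesting it; the model name, task list and n-grams path are placeholders, and the exact simple_evaluate keywords are inferred from the main.py changes below rather than shown in this hunk:

    # Illustrative sketch only; assumes simple_evaluate accepts these keywords (see main.py below).
    import lm_eval.evaluator

    results = lm_eval.evaluator.simple_evaluate(
        model="gpt2",
        tasks=["lambada"],
        decontaminate=True,
        decontamination_ngrams_path="path/to/13grams",
    )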
main.py
...
@@ -35,25 +35,11 @@ def parse_args():
     parser.add_argument('--output_path', default=None)
     parser.add_argument('--limit', type=int, default=None)
     parser.add_argument('--no_cache', action="store_true")
     parser.add_argument('--decontaminate', action="store_true")
-    parser.add_argument('--decontaminate_ngrams_path', default=None)
-    parser.add_argument('--decontaminate_ngrams_n_size', type=int, default=None)
+    parser.add_argument('--decontamination_ngrams_path', default=None)
     parser.add_argument('--description_dict_path', default=None)
     return parser.parse_args()

-def ensure_correct_decontamination_params(args):
-    valid = True
-    if args.decontaminate:
-        if not args.decontaminate_ngrams_n_size:
-            print("Please specify n size of training set n-grams. (--ngrams_n_size)")
-            valid = False
-        if not args.decontaminate_ngrams_path:
-            print("Please specify path containing training set n-grams. (--ngrams_path)")
-            valid = False
-    return valid
-
 # Returns a list containing all values of the source_list that
 # match at least one of the patterns
 def pattern_match(patterns, source_list):
...
@@ -65,8 +51,6 @@ def pattern_match(patterns, source_list):
 def main():
     args = parse_args()
-    if not ensure_correct_decontamination_params(args):
-        return
-
     assert not args.provide_description  # not implemented
...
@@ -95,9 +79,7 @@ def main():
         no_cache=args.no_cache,
         limit=args.limit,
         description_dict=description_dict,
         decontaminate=args.decontaminate,
-        decontaminate_ngrams_path=args.decontaminate_ngrams_path,
-        decontaminate_ngrams_n_size=args.decontaminate_ngrams_n_size
+        decontamination_ngrams_path=args.decontamination_ngrams_path
     )

     dumped = json.dumps(results, indent=2)
...
pile_statistics.json  0 → 100644 (new file)
{
    "Data": "Pile statistics",
    "Document Count": 210607728,
    "Total Pile Characters": 421215456,
    "File Start Offsets": [
        0,
        7021438,
        14042822,
        21066113,
        28086515,
        35106072,
        42123306,
        49145091,
        56165817,
        63185587,
        70211208,
        77234322,
        84249267,
        91267634,
        98285983,
        105305110,
        112322489,
        119342491,
        126367373,
        133389153,
        140412039,
        147432373,
        154452516,
        161470190,
        168492733,
        175512521,
        182526939,
        189547478,
        196565318,
        203583306
    ]
}
\ No newline at end of file
scripts/clean_training_data/generate_13_grams.py
...
@@ -21,8 +21,10 @@ Arguments
 """

 import argparse
+import json
 import pickle
 import os
+import sys
 from pathlib import Path
 import glob
 import signal
...
@@ -30,32 +32,89 @@ from signal import SIGINT
 from tqdm import tqdm

-from scripts.clean_training_data.janitor import Janitor, word_ngrams
-from scripts.clean_training_data.archiver import TextArchive, Reader
+from lm_eval.decontamination.janitor import Janitor, word_ngrams
+from lm_eval.decontamination.archiver import TextArchive, Reader

 import logging
 from tqdm_multiprocess.logger import setup_logger_tqdm
 logger = logging.getLogger(__name__)

-pile_document_count = 210607728
-
 terminate = False
 def handler(signal_received, frame):
     global terminate
     terminate = True

-def get_pile(directory):
-    reader = Reader()
-    for file in glob.glob(os.path.join(directory, f"*.jsonl.zst*")):
-        for document in reader.read(file):
-            yield document
-
-def close_buckets(buckets):
-    for bucket in buckets:
-        bucket.commit()
+def yield_pile(start_offsets=None, checkpoint_offset=None):
+    directory = "pile"
+
+    if not os.path.exists(directory):
+        print("We expect the pile archives to be in the 'pile' directory, but this was not found.")
+        raise Exception("Pile directory not found.")
+
+    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
+
+    pile_global_offset = 0
+    start_file = 0
+    if checkpoint_offset:
+        for file_i, start_offset in enumerate(start_offsets):
+            if start_offset > checkpoint_offset:
+                break
+
+            start_file = file_i
+            pile_global_offset = start_offset
+
+    for file_i, file in enumerate(files):
+        if file_i < start_file:
+            logger.info(f"Skipping file {file}")
+            continue
+
+        logger.info(f"Reading from pile file: {file}")
+        reader = Reader()
+        for document in reader.read(file):
+            yield (pile_global_offset, document)
+            pile_global_offset += 1
+
+# Hash buckets > disk backed files. Supports file position checkpointing and resuming
+# Allows you to write continuously and checkpoint intermittently. If a failure occurs
+# the buckets are simply truncated at your last checkpoint.
+class Buckets:
+    def __init__(self, directory, num_buckets):
+        self.bucket_files = [os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)]
+        self.buckets = list(map(TextArchive, self.bucket_files))
+        self.checkpoint_file = os.path.join(directory, f"bucket_offsets.ckpt")
+
+        if os.path.exists(self.checkpoint_file):
+            self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
+        else:
+            self.bucket_offsets = [0 for i in range(len(self.buckets))]
+
+        for i, offset in enumerate(self.bucket_offsets):
+            bucket = self.buckets[i]
+            bucket.fh.seek(offset)
+            bucket.fh.truncate()
+
+    def add_data(self, key, value):
+        i = hash(key) % len(self.buckets)
+        bucket = self.buckets[i]
+        bucket.add_data(value)
+
+    def save_checkpoint(self):
+        for bucket in self.buckets:
+            bucket.fh.flush()
+
+        bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
+        pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))
+
+    def close_buckets(self):
+        for bucket in self.buckets:
+            bucket.commit()

 def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
+    pile_statistics = json.load(open("pile_statistics.json", "r"))
+    pile_document_count = pile_statistics["Document Count"]
+    start_offsets = pile_statistics["File Start Offsets"]
+
     output_directory = os.path.join(working_directory, "output")
     os.makedirs(output_directory, exist_ok=True)
...
@@ -68,49 +127,56 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
         return

     # Checkpoint
-    checkpoint_file = os.path.join(output_directory, f"ngram_buckets.ckpt")
+    checkpoint_file = os.path.join(working_directory, f"pile_offset.ckpt")
     if os.path.exists(checkpoint_file):
-        start_id = pickle.load(open(checkpoint_file, "rb"))
+        checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
+        iterate = True
     else:
-        start_id = 0
+        checkpoint_offset = 0
+        iterate = False

-    logger.info(f"Starting at pile document index {start_id}")
+    logger.info(f"Starting at pile document index {checkpoint_offset}")

-    bucket_files = [os.path.join(output_directory, f"ngrams_{i}.bkt.txt") for i in range(bucket_count)]
-    buckets = list(map(TextArchive, bucket_files))
+    buckets = Buckets(output_directory, bucket_count)

     janitor = Janitor()

-    current_id = 0
     batch_size = 1000
     batch_counter = 0

-    with tqdm(total=pile_document_count, dynamic_ncols=True, unit="docs") as progress:
-        for document in get_pile(working_directory):
-            if current_id < start_id:
-                if terminate:
-                    close_buckets(buckets)
-                    return
-                current_id += 1
-                progress.update()
-                continue
+    with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
+        for offset, document in yield_pile(start_offsets, checkpoint_offset):
+            if iterate:
+                logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
+                progress.update(offset)
+                iterate = False
+
+            if offset < checkpoint_offset:
+                progress.update()
+                if terminate:
+                    return
+                continue
+
+            if offset == checkpoint_offset:
+                progress.reset(total=pile_document_count)
+                progress.update(checkpoint_offset)

             # Save checkpoint every "batch_size", only allow terminate after checkpoint
             if batch_counter == batch_size:
                 progress.update(batch_size)
                 batch_counter = 0
-                pickle.dump(current_id, open(checkpoint_file, "wb"))
+                buckets.save_checkpoint()
+                pickle.dump(offset, open(checkpoint_file, "wb"))
                 if terminate:
-                    close_buckets(buckets)
+                    buckets.close_buckets()
                     return

             ngrams = word_ngrams(janitor.normalize_string(document), n_value)
             for ngram in ngrams:
-                bucket = hash(ngram) % len(buckets)
-                buckets[bucket].add_data(f"{ngram} {current_id}")
+                buckets.add_data(ngram, f"{ngram} {offset}")

             batch_counter += 1
-            current_id += 1

-    close_buckets(buckets)
+    buckets.close_buckets()
     Path(done_file).touch()
...
@@ -120,6 +186,12 @@ parser.add_argument("-n", "--n_value", type=int, default=13)
 parser.add_argument("-buckets", "--bucket_count", type=int, default=500)

 if __name__ == '__main__':
+    version = 1.00
+    print(f"Running version {version}")
+
+    if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
+        print("Please run 'export PYTHONHASHSEED=0' before running generate.")
+        sys.exit()
+
     # Handle sigint (ctrl-c) cleanly
     previous_signal_int = signal.signal(SIGINT, handler)
...
@@ -128,4 +200,8 @@ if __name__ == '__main__':
     setup_logger_tqdm(logfile_path)

     args = parser.parse_args()
-    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
\ No newline at end of file
+    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
+
+    info_dict = {"title": "dataset ngrams", "ngram_size": 13}
+    info_dict_path = os.path.join(args.working_directory, "info.json")
+    json.dump(info_dict, open(info_dict_path, "w"))
\ No newline at end of file
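
The new Buckets class is what makes 13-gram generation resumable: values are routed to hash-addressed bucket files, and save_checkpoint() records a byte offset per bucket so a restart can truncate any partial writes. A minimal sketch of that pattern in isolation (the directory and strings are placeholders; reproducible bucket assignment across runs also relies on the PYTHONHASHSEED=0 check added above, and on the repo's requirements such as tqdm_multiprocess being installed):

    # Illustrative sketch only; directory and contents are placeholders.
    from scripts.clean_training_data.generate_13_grams import Buckets

    buckets = Buckets("scratch_buckets", 4)   # reopens existing files and truncates to the last checkpoint
    buckets.add_data("some thirteen gram here", "some thirteen gram here 42")  # bucket chosen by hash(key)
    buckets.save_checkpoint()                 # flush, then pickle per-bucket offsets to bucket_offsets.ckpt
    buckets.close_buckets()                   # commit (flush + close) every bucket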
scripts/clean_training_data/investigate_pile.py  0 → 100644 (new file)
from lm_eval.decontamination.archiver import Reader

import os
import json
from functools import reduce
import glob
import tqdm

from tqdm_multiprocess import TqdmMultiProcessPool

def get_file_stats(file_path, tqdm_func, global_tqdm):
    reader = Reader()
    total_documents = 0
    total_size = 0
    update_frequency = 10000
    current_file_position = 0

    with tqdm_func(total=os.path.getsize(file_path), dynamic_ncols=True, unit="byte", unit_scale=1) as progress:
        for document in reader.read(file_path, get_meta=True):
            total_size += len(document)
            total_documents += 1

            if total_documents % update_frequency == 0:
                new_file_pos = reader.fh.tell()
                bytes_read = new_file_pos - current_file_position
                current_file_position = new_file_pos
                progress.update(bytes_read)
                global_tqdm.update(bytes_read)

    return (total_documents, total_size)

def get_files():
    directory = "pile"
    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
    print(files)
    return files

def get_stats():
    files = get_files()
    total_size_bytes = sum(map(lambda x: os.path.getsize(x), files))

    pool = TqdmMultiProcessPool(4)
    global_tqdm = tqdm.tqdm(total=total_size_bytes, dynamic_ncols=True, unit="byte", unit_scale=1)

    # Generate minhashes with pool
    tasks = [(get_file_stats, (file,)) for file in files]

    on_done = lambda _: None
    on_error = lambda _: None
    results = pool.map(global_tqdm, tasks, on_error, on_done)

    total_documents, total_size = reduce(lambda x, y: (x[0] + y[0], x[1] + y[1]), results)

    start_offsets = []
    current_offset = 0
    for file_document_count, _ in results:
        start_offsets.append(current_offset)
        current_offset += file_document_count

    return (total_documents, total_size, start_offsets)

if __name__ == '__main__':
    version = 1.01
    print(f"Running version {version}")

    stats_file_path = "pile_statistics.json"
    if os.path.exists(stats_file_path):
        stats = json.load(open(stats_file_path, "r"))
    else:
        document_count, total_document_size_chars, start_offsets = get_stats()
        stats = {"Data": "Pile statistics",
                 "Document Count": document_count,
                 "Total Pile Characters": total_document_size_chars,
                 "File Start Offsets": start_offsets}
        json.dump(stats, open(stats_file_path, "w"), indent=4)

    print(f"document_count: {stats['Document Count']}")
    print(f"total_chars: {stats['Total Pile Characters']}")
    print(f"start_offsets: {stats['File Start Offsets']}")
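
investigate_pile.py is what produced the committed pile_statistics.json above; generate_13_grams.py then reads the document count and per-file start offsets back out. A small sketch of that consumption, mirroring the do_ngrams_in_buckets() changes:

    # Illustrative sketch mirroring how do_ngrams_in_buckets() consumes the statistics file.
    import json

    with open("pile_statistics.json", "r") as fh:
        pile_statistics = json.load(fh)

    pile_document_count = pile_statistics["Document Count"]  # 210607728 in the committed file
    start_offsets = pile_statistics["File Start Offsets"]    # global index of each shard's first document

    # yield_pile(start_offsets, checkpoint_offset) uses these offsets to skip whole shards
    # when resuming from a checkpoint instead of re-reading every document.
    print(pile_document_count, start_offsets[:3])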
tests/test_generate_13_grams.py
...
@@ -3,12 +3,14 @@ from collections import Counter
 import shutil
 import glob

-from scripts.clean_training_data.janitor import *
+from lm_eval.decontamination.janitor import *
 from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets
-from scripts.clean_training_data.archiver import Archive, TextReader
+from lm_eval.decontamination.archiver import Archive, TextReader

+import logging
+logger = logging.getLogger(__name__)

-def test_generate_13_grams_1():
+def test_generate_13_grams_1(caplog):
     data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
 This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
 Some other birds, mostly related to the shelducks, have "goose" as part of their names.
...
@@ -22,6 +24,7 @@ def test_generate_13_grams_1():
     data = data + data

     # Simple Generation
+    print("simple generation")
     n = 13
     janitor = Janitor()
     ngrams = word_ngrams(janitor.normalize_string(data), n)
...
@@ -31,22 +34,26 @@ def test_generate_13_grams_1():
     # print(comparison)

     # Generating into buckets
+    print("bucket generation")
     test_working_directory = "test_generate_13_grams"
-    output_directory = os.path.join(test_working_directory, "output")
     try:
-        shutil.rmtree(output_directory)
+        shutil.rmtree(test_working_directory)
     except FileNotFoundError:
         pass
-    os.makedirs(test_working_directory, exist_ok=True)
+    os.makedirs(test_working_directory)

-    archive = Archive(os.path.join(test_working_directory, "test.jsonl.zst"))
+    assert(not os.path.exists("pile"))
+    os.makedirs("pile")
+    archive = Archive(os.path.join("pile", "test.jsonl.zst"))
     archive.add_data(data)
     archive.commit()

     bucket_count = 4
     do_ngrams_in_buckets(n, test_working_directory, bucket_count)

     # Rebuild from buckets
+    print("rebuild")
     rebuilt_ngrams = []
     bucket_file_paths = glob.glob(os.path.join(test_working_directory, "output", f"*.bkt.txt"))
     for bucket_file_path in bucket_file_paths:
         reader = TextReader(bucket_file_path)
...
@@ -56,6 +63,7 @@ def test_generate_13_grams_1():
         rebuilt_ngrams.append(ngram)

     # Compare
+    print("compare")
     result_counter = Counter(rebuilt_ngrams)
     # print(len(result_counter))
     # print(len(comparison_counter))
...
tests/test_janitor.py
 import re
 from collections import defaultdict

-from scripts.clean_training_data.janitor import *
+from lm_eval.decontamination.janitor import *

 def simple_ngram(sequence, n):
     ngrams = list()
...