gaoqiong / lm-evaluation-harness · Commits

Commit b2460099, authored Feb 24, 2022 by Leo Gao

    pull stuff into lm_eval.decontamination

Parent: 98d75af0
Changes: 3 changed files with 438 additions and 2 deletions

    lm_eval/decontamination.py              +435  -0
    lm_eval/evaluator.py                      +2  -2
    scripts/clean_training_data/janitor.py    +1  -0
scripts/clean_training_data/contamination.py → lm_eval/decontamination.py
import mmap
import string
import time
import random
import pickle
import glob
import os
import collections
import re

from scripts.clean_training_data.janitor import Janitor, word_ngrams

import tqdm

from scripts.clean_training_data.archiver import ZStdTextReader

try:
    import janitor_util
    JANITOR_CPP = True
except Exception as e:
    print("WARNING: C++ module could not be loaded. Janitor running in python mode")
    JANITOR_CPP = False


# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
    ...

@@ -147,3 +156,280 @@ def get_train_overlap(docs_by_task_set, ngrams_path, ngrams_n_size, limit):
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
    history = []
    while n > 1:
        # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
        try:
            next_item = next(sequence)
        except StopIteration:
            # no more data, terminate the generator
            return
        history.append(next_item)
        n -= 1

    for item in sequence:
        history.append(item)
        yield tuple(history)
        del history[0]


def word_ngrams(s, n):
    """Splits a string into ngram words"""
    tokens = s.split()  # not a generator :(
    ngram_seqs = form_ngrams(iter(tokens), n)
    return (" ".join(ngram) for ngram in ngram_seqs)
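
For illustration only (not part of this diff): a minimal sketch of what these two helpers yield, assuming this module is importable as lm_eval.decontamination.

# Illustrative sketch, not part of the commit.
from lm_eval.decontamination import form_ngrams, word_ngrams

print(list(form_ngrams(iter([1, 2, 3, 4]), 2)))
# [(1, 2), (2, 3), (3, 4)]

print(list(word_ngrams("the quick brown fox jumps", 3)))
# ['the quick brown', 'quick brown fox', 'brown fox jumps']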
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s):
    """Splits a string on whitespaces and records the indices of each in the original string.
    @:return generator((word, (start_idx, end_idx)), ...)
    """
    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))


def word_ngrams_indices(s, n):
    """Splits a string into pairs of (ngram words, their start/end indices)"""
    tokens_with_indices = split_indices(s)

    # Generator of ngrams of (word, idx_pairs)
    # (
    #   [(word, (start,end)), (word, (start, end))...],
    #   [(word, (start, end)), ...],
    #   ...
    # )
    ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)

    # Generator of pairs of word and index ngrams
    # (
    #   ([word, word, ...], [(start,end), (start,end), ...]),
    #   ...
    # )
    ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)

    # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
    return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
            for ngram_seq, indices in ngram_indices_pairs)
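
For illustration only (not part of this diff): the index pairs these helpers produce are character offsets back into the original string, which is what lets the Janitor below splice contaminated regions out of the raw text.

# Illustrative sketch, not part of the commit.
from lm_eval.decontamination import split_indices, word_ngrams_indices

s = "a bb ccc"
print(list(split_indices(s)))
# [('a', (0, 0)), ('bb', (2, 3)), ('ccc', (5, 7))]

print(list(word_ngrams_indices(s, 2)))
# [('a bb', (0, 3)), ('bb ccc', (2, 7))]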
class Janitor:

    # FIXME delete_chars: Should anything else go here? Special chars?
    def __init__(self, ngram_n=13, window_to_remove=200, too_dirty_cutoff=10,
                 minimum_slice_length=200, delete_chars=string.punctuation):
        self.ngram_n = ngram_n
        self.window_to_remove = window_to_remove
        self.too_dirty_cutoff = too_dirty_cutoff
        self.minimum_slice_length = minimum_slice_length
        self.delete_chars = delete_chars
        self.dirt_ngrams = set()

        # If in python, we'll translate uppercase to lowercase and delete naughty characters.
        # This is fast by python standards
        # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
        self.translation_table = str.maketrans(
            string.ascii_lowercase + string.ascii_uppercase,  # These characters
            string.ascii_lowercase * 2,  # Become these characters
            self.delete_chars  # These are deleted
        )
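
For illustration only (not part of this diff): the translation table built above lowercases ASCII letters and deletes delete_chars (punctuation by default), which is the normalization applied before ngrams are compared.

# Illustrative sketch, not part of the commit.
import string

table = str.maketrans(
    string.ascii_lowercase + string.ascii_uppercase,  # these characters
    string.ascii_lowercase * 2,                       # become these characters
    string.punctuation,                               # these are deleted
)
print("Hello, World! It's 2022.".translate(table))
# hello world its 2022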
    ##############
    # I/O for saving contamination ngrams
    ##############

    def save_contamination_ngrams(self, filename):
        with open(filename, 'wb') as fp:
            pickle.dump(filename, fp)

    def load_contamination_ngrams(self, filename):
        with open(filename, 'rb') as fp:
            self.dirt_ngrams = pickle.load(fp)
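
Note that, as written, save_contamination_ngrams pickles the filename string rather than self.dirt_ngrams, so a save/load round trip would not restore the registered ngrams; presumably the intent was something along these lines (illustrative sketch, not part of this diff):

# Illustrative sketch, not part of the commit: dump the ngram set itself.
def save_contamination_ngrams(self, filename):
    with open(filename, 'wb') as fp:
        pickle.dump(self.dirt_ngrams, fp)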
    ##############
    # Call these :)
    ##############

    def register_contaminant(self, dirt_string):
        """Register a string as contamination to be removed, e.g. a test set
        This breaks the dirt_string into ngrams to store for future cleaning"""
        if JANITOR_CPP:
            return self.register_contaminant_cpp(dirt_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.register_contaminant_python(dirt_string)

    def clean(self, dirty_string):
        """Clean a string (e.g. a training set) by removing all ngrams previously
        registered as contaminants. Returns a list of clean chunks, or empty if
        the string was too dirty"""
        if JANITOR_CPP:
            return self.clean_cpp(dirty_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.clean_python(dirty_string)
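
For illustration only (not part of this diff): typical usage registers benchmark documents as contaminants and then cleans training documents; clean() returns the surviving chunks, or an empty list once too_dirty_cutoff matches are hit. The text and parameters below are made up.

# Illustrative sketch, not part of the commit.
from lm_eval.decontamination import Janitor

janitor = Janitor(ngram_n=3, window_to_remove=5, too_dirty_cutoff=10, minimum_slice_length=5)

# Register a benchmark document as contamination ...
janitor.register_contaminant("the quick brown fox jumps over the lazy dog")

# ... then clean a training document. Regions matching a registered ngram (plus a
# window on each side) are cut out; the remaining clean chunks come back as a list.
chunks = janitor.clean(
    "lots of training text and then the quick brown fox jumps over the lazy dog and then more training text"
)
print(chunks)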
    def _split_chunks(self, dirty_string, dirty_parts):
        clean_chunks = []
        splice_idx = 0
        end = -1
        for i, (ngram, start, end) in enumerate(dirty_parts):
            if i >= self.too_dirty_cutoff:
                return []
            start = max(0, start - self.window_to_remove)
            end = min(len(dirty_string), end + self.window_to_remove)

            if start - splice_idx > self.minimum_slice_length:
                clean_chunks.append(dirty_string[splice_idx:start])
            splice_idx = end

        if end < len(dirty_string) - self.minimum_slice_length:
            clean_chunks.append(dirty_string[end + 1:])

        return clean_chunks
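
For illustration only (not part of this diff): how the windowing parameters interact for a single contaminated span, using made-up offsets.

# Illustrative sketch, not part of the commit.
from lm_eval.decontamination import Janitor

j = Janitor(window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200)
dirty_string = "x" * 1000

# One contaminated ngram reported at characters 400-450 (the ngram text itself is unused here).
# start = max(0, 400 - 200) = 200, end = min(1000, 450 + 200) = 650.
# The 200-char prefix is dropped (not strictly longer than minimum_slice_length),
# while the 349-char suffix dirty_string[651:] is long enough to keep.
chunks = j._split_chunks(dirty_string, [(None, 400, 450)])
print([len(c) for c in chunks])
# [349]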
    ##############
    # Fast C++
    ##############

    def register_contaminant_cpp(self, dirt_string):
        self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))

    def clean_cpp(self, dirty_string):
        contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n)
        return self._split_chunks(dirty_string, contamination_indices)
    ##############
    # Slow python
    ##############

    def normalize_string(self, s):
        return s.translate(self.translation_table)

    def register_contaminant_python(self, dirt_string):
        self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n))

    def clean_python(self, dirty_string):
        contamination_indices = (
            (None, *idx_pair)
            for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
            if self.normalize_string(dirty_ngram) in self.dirt_ngrams
        )
        return self._split_chunks(dirty_string, contamination_indices)
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
    history = []
    while n > 1:
        # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
        try:
            next_item = next(sequence)
        except StopIteration:
            # no more data, terminate the generator
            return
        history.append(next_item)
        n -= 1

    for item in sequence:
        history.append(item)
        yield tuple(history)
        del history[0]


def word_ngrams(s, n):
    """Splits a string into ngram words"""
    tokens = s.split()  # not a generator :(
    ngram_seqs = form_ngrams(iter(tokens), n)
    return (" ".join(ngram) for ngram in ngram_seqs)
# Simple text reader and writer with same interface as above
class TextArchive:
    def __init__(self, file_path, mode="ab"):
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, mode)

    def add_data(self, data, meta={}):
        self.fh.write(data.encode('UTF-8') + b'\n')

    def commit(self):
        self.fh.flush()
        self.fh.close()
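
For illustration only (not part of this diff): TextArchive appends newline-delimited UTF-8 text, and the TextReader below (same interface family) reads it back line by line. The path is made up.

# Illustrative sketch, not part of the commit.
from lm_eval.decontamination import TextArchive, TextReader

archive = TextArchive("output/sample.txt")
archive.add_data("first line")
archive.add_data("second line")
archive.commit()

for line in TextReader("output/sample.txt").read():
    print(line)
# first line
# second line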
class TextReader:
    def __init__(self, file_path):
        self.file_path = file_path

    # Optimized mmap read with infrequent tqdm updates to maintain speed
    # Tested up to 250MB/s.
    def read_tqdm(self, update_frequency=10000):
        current_file_position = 0
        line_counter = 0
        with open(self.file_path, 'r') as fh, \
            tqdm.tqdm(total=os.path.getsize(self.file_path),
                      dynamic_ncols=True, unit="byte", unit_scale=1) as progress:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    line_counter += 1
                    if line_counter == update_frequency:
                        new_file_pos = mmap_obj.tell()
                        bytes_read = new_file_pos - current_file_position
                        current_file_position = new_file_pos
                        progress.update(bytes_read)
                        line_counter = 0
                    yield line[:-1]

    def read_and_tell(self):
        current_file_position = 0
        with open(self.file_path, 'r', encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    new_file_pos = mmap_obj.tell()
                    raw_bytes_read = new_file_pos - current_file_position
                    current_file_position = new_file_pos
                    yield line[:-1], raw_bytes_read

    def read(self):
        with open(self.file_path, 'r', encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    yield line[:-1]

    def read_slow(self):
        with open(self.file_path, 'r', encoding="utf8") as fh:
            while True:
                line = fh.readline()
                if line == -1 or line == "":
                    break
                else:
                    yield line[:-1]
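
For illustration only (not part of this diff): read_tqdm() streams a large newline-delimited file through mmap with a byte-level progress bar, while read_and_tell() additionally reports the raw bytes each line consumed. The path is made up.

# Illustrative sketch, not part of the commit.
from lm_eval.decontamination import TextReader

reader = TextReader("data/training_shard.jsonl")

for line in reader.read_tqdm():
    pass  # process each line here; the progress bar updates every 10000 lines

for line, raw_bytes_read in reader.read_and_tell():
    pass  # raw_bytes_read is the on-disk size of this line, newline included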
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
    def __init__(self, file):
        self.file = file

    def read_tqdm(self):
        decompressed_file = self.file[:-4]
        print("Decompressing file, please wait...")
        os.system(f"zstd -d {self.file}")  # linux decompress is faster
        reader = TextReader(decompressed_file)
        yield from reader.read_tqdm()
        os.remove(decompressed_file)
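
For illustration only (not part of this diff): ZStdTextReader expects a .zst path, shells out to zstd -d, streams the decompressed file, and removes it afterwards. The path is made up.

# Illustrative sketch, not part of the commit.
from lm_eval.decontamination import ZStdTextReader

for line in ZStdTextReader("archives/shard.jsonl.zst").read_tqdm():
    pass  # each decompressed line, without its trailing newline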
lm_eval/evaluator.py
@@ -5,7 +5,7 @@ import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
 import lm_eval.base
-from scripts.clean_training_data.contamination import get_train_overlap
+import lm_eval.decontamination
 import numpy as np
 from lm_eval.utils import positional_deprecated
@@ -193,7 +193,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     # Compare all tasks/sets at once to ensure a single training set scan
     if decontaminate:
         print("Finding train/test overlap, please wait...")
-        overlaps = get_train_overlap(docs_for_decontamination, decontaminate_ngrams_path, decontaminate_ngrams_n_size, limit)
+        overlaps = lm_eval.decontamination.get_train_overlap(docs_for_decontamination, decontaminate_ngrams_path, decontaminate_ngrams_n_size, limit)

     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)
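
For illustration only (not part of this diff): after this change, evaluation reaches decontamination through lm_eval.decontamination.get_train_overlap inside evaluate(). A hedged sketch of invoking it, assuming the truncated signature above continues with the decontaminate* parameters used in the hunk; the model construction and paths are placeholders.

# Illustrative sketch, not part of the commit.
import lm_eval.evaluator
import lm_eval.tasks

lm = ...  # an already-constructed lm_eval model instance (construction omitted)

results = lm_eval.evaluator.evaluate(
    lm=lm,
    task_dict=lm_eval.tasks.get_task_dict(["lambada"]),
    num_fewshot=0,
    limit=None,
    decontaminate=True,                             # assumed keyword, per the hunk above
    decontaminate_ngrams_path="path/to/ngrams",     # made-up path
    decontaminate_ngrams_n_size=13,                 # 13-grams, matching Janitor's default
)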
scripts/clean_training_data/janitor.py
@@ -4,6 +4,7 @@ import timeit
 import pickle
 import traceback
 from pprint import pprint
+from lm_eval.decontamination import word_ngrams

 # This is a cpp module. Compile janitor_util.cpp with:
 # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
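
For illustration only (not part of this diff): the comment above builds janitor_util with a raw c++ invocation; the same extension can presumably also be built with pybind11's setuptools helpers, roughly as follows (hypothetical setup.py, not in the repo).

# Illustrative sketch, not part of the commit: hypothetical setup.py for janitor_util.
from pybind11.setup_helpers import Pybind11Extension, build_ext
from setuptools import setup

setup(
    name="janitor_util",
    ext_modules=[Pybind11Extension("janitor_util", ["janitor_util.cpp"])],
    cmdclass={"build_ext": build_ext},
)
# Build in place with: python setup.py build_ext --inplace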