gaoqiong / lm-evaluation-harness · Commits

Commit b2460099
authored Feb 24, 2022 by Leo Gao

pull stuff into lm_eval.decontamination

parent 98d75af0
Showing 3 changed files with 438 additions and 2 deletions (+438 −2)
lm_eval/decontamination.py              +435 −0
lm_eval/evaluator.py                    +2 −2
scripts/clean_training_data/janitor.py  +1 −0
scripts/clean_training_data/contamination.py → lm_eval/decontamination.py
import mmap
import string
import time
import random
import pickle
import glob
import os
import collections
import re

from scripts.clean_training_data.janitor import Janitor, word_ngrams
from scripts.clean_training_data.archiver import ZStdTextReader

import tqdm

try:
    import janitor_util
    JANITOR_CPP = True
except Exception as e:
    print("WARNING: C++ module could not be loaded. Janitor running in python mode")
    JANITOR_CPP = False


# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
...
...
@@ -147,3 +156,280 @@ def get_train_overlap(docs_by_task_set, ngrams_path, ngrams_n_size, limit):
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
    history = []
    while n > 1:
        # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
        try:
            next_item = next(sequence)
        except StopIteration:
            # no more data, terminate the generator
            return
        history.append(next_item)
        n -= 1

    for item in sequence:
        history.append(item)
        yield tuple(history)
        del history[0]


def word_ngrams(s, n):
    """Splits a string into ngram words"""
    tokens = s.split()  # not a generator :(
    ngram_seqs = form_ngrams(iter(tokens), n)
    return (" ".join(ngram) for ngram in ngram_seqs)
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s):
    """Splits a string on whitespaces and records the indices of each in the original string.
    @:return generator((word, (start_idx, end_idx)), ...)
    """
    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))
def word_ngrams_indices(s, n):
    """Splits a string into pairs of (ngram words, their start/end indices)"""
    tokens_with_indices = split_indices(s)

    # Generator of ngrams of (word, idx_pairs)
    # (
    #   [(word, (start, end)), (word, (start, end)), ...],
    #   [(word, (start, end)), ...],
    #   ...
    # )
    ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)

    # Generator of pairs of word and index ngrams
    # (
    #   ([word, word, ...], [(start, end), (start, end), ...]),
    #   ...
    # )
    ngram_indices_pairs = (
        zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)

    # Generator of ((word_ngram, (start, end)), (word_ngram, (start, end)), ...)
    return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
            for ngram_seq, indices in ngram_indices_pairs)
class Janitor:

    # FIXME delete_chars: Should anything else go here? Special chars?
    def __init__(self, ngram_n=13, window_to_remove=200, too_dirty_cutoff=10,
                 minimum_slice_length=200, delete_chars=string.punctuation):
        self.ngram_n = ngram_n
        self.window_to_remove = window_to_remove
        self.too_dirty_cutoff = too_dirty_cutoff
        self.minimum_slice_length = minimum_slice_length
        self.delete_chars = delete_chars

        self.dirt_ngrams = set()

        # If in python, we'll translate uppercase to lowercase and delete naughty characters.
        # This is fast by python standards
        # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
        self.translation_table = str.maketrans(
            string.ascii_lowercase + string.ascii_uppercase,  # These characters
            string.ascii_lowercase * 2,                       # Become these characters
            self.delete_chars                                 # These are deleted
        )
    ##############
    # I/O for saving contamination ngrams
    ##############

    def save_contamination_ngrams(self, filename):
        with open(filename, 'wb') as fp:
            # persist the registered ngram set so load_contamination_ngrams can restore it
            pickle.dump(self.dirt_ngrams, fp)

    def load_contamination_ngrams(self, filename):
        with open(filename, 'rb') as fp:
            self.dirt_ngrams = pickle.load(fp)
    ##############
    # Call these :)
    ##############

    def register_contaminant(self, dirt_string):
        """Register a string as contamination to be removed, e.g. a test set
        This breaks the dirt_string into ngrams to store for future cleaning"""
        if JANITOR_CPP:
            return self.register_contaminant_cpp(dirt_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.register_contaminant_python(dirt_string)

    def clean(self, dirty_string):
        """Clean a string (e.g. a training set) by removing all ngrams previously
        registered as contaminants. Returns a list of clean chunks, or empty if
        the string was too dirty"""
        if JANITOR_CPP:
            return self.clean_cpp(dirty_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.clean_python(dirty_string)
    def _split_chunks(self, dirty_string, dirty_parts):
        clean_chunks = []
        splice_idx = 0
        end = -1
        for i, (ngram, start, end) in enumerate(dirty_parts):
            if i >= self.too_dirty_cutoff:
                return []
            start = max(0, start - self.window_to_remove)
            end = min(len(dirty_string), end + self.window_to_remove)

            if start - splice_idx > self.minimum_slice_length:
                clean_chunks.append(dirty_string[splice_idx:start])
            splice_idx = end

        if end < len(dirty_string) - self.minimum_slice_length:
            clean_chunks.append(dirty_string[end + 1:])

        return clean_chunks
    ##############
    # Fast C++
    ##############

    def register_contaminant_cpp(self, dirt_string):
        self.dirt_ngrams.update(
            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))

    def clean_cpp(self, dirty_string):
        contamination_indices = janitor_util.clean_ngram_with_indices(
            dirty_string, self.delete_chars, self.ngram_n)
        return self._split_chunks(dirty_string, contamination_indices)
    ##############
    # Slow python
    ##############

    def normalize_string(self, s):
        return s.translate(self.translation_table)

    def register_contaminant_python(self, dirt_string):
        self.dirt_ngrams.update(
            word_ngrams(self.normalize_string(dirt_string), self.ngram_n))

    def clean_python(self, dirty_string):
        contamination_indices = (
            (None, *idx_pair)
            for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
            if self.normalize_string(dirty_ngram) in self.dirt_ngrams
        )
        return self._split_chunks(dirty_string, contamination_indices)
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
    history = []
    while n > 1:
        # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
        try:
            next_item = next(sequence)
        except StopIteration:
            # no more data, terminate the generator
            return
        history.append(next_item)
        n -= 1

    for item in sequence:
        history.append(item)
        yield tuple(history)
        del history[0]


def word_ngrams(s, n):
    """Splits a string into ngram words"""
    tokens = s.split()  # not a generator :(
    ngram_seqs = form_ngrams(iter(tokens), n)
    return (" ".join(ngram) for ngram in ngram_seqs)
# Simple text reader and writer with same interface as above
class TextArchive:
    def __init__(self, file_path, mode="ab"):
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, mode)

    def add_data(self, data, meta={}):
        self.fh.write(data.encode('UTF-8') + b'\n')

    def commit(self):
        self.fh.flush()
        self.fh.close()
class TextReader:
    def __init__(self, file_path):
        self.file_path = file_path

    # Optimized mmap read with infrequent tqdm updates to maintain speed
    # Tested up to 250MB/s.
    def read_tqdm(self, update_frequency=10000):
        current_file_position = 0
        line_counter = 0
        with open(self.file_path, 'r') as fh, \
                tqdm.tqdm(total=os.path.getsize(self.file_path),
                          dynamic_ncols=True, unit="byte", unit_scale=1) as progress:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    line_counter += 1
                    if line_counter == update_frequency:
                        new_file_pos = mmap_obj.tell()
                        bytes_read = new_file_pos - current_file_position
                        current_file_position = new_file_pos
                        progress.update(bytes_read)
                        line_counter = 0
                    yield line[:-1]
    def read_and_tell(self):
        current_file_position = 0
        with open(self.file_path, 'r', encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    new_file_pos = mmap_obj.tell()
                    raw_bytes_read = new_file_pos - current_file_position
                    current_file_position = new_file_pos
                    yield line[:-1], raw_bytes_read
    def read(self):
        with open(self.file_path, 'r', encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    yield line[:-1]
    def read_slow(self):
        with open(self.file_path, 'r', encoding="utf8") as fh:
            while True:
                line = fh.readline()
                if line == -1 or line == "":
                    break
                else:
                    yield line[:-1]
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
    def __init__(self, file):
        self.file = file

    def read_tqdm(self):
        decompressed_file = self.file[:-4]
        print("Decompressing file, please wait...")
        os.system(f"zstd -d {self.file}")  # linux decompress is faster
        reader = TextReader(decompressed_file)
        yield from reader.read_tqdm()
        os.remove(decompressed_file)
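
For orientation, a minimal usage sketch of the Janitor class defined above. The strings and parameters are illustrative only (the committed defaults are 13-grams, a 200-character removal window and a 200-character minimum slice), and it assumes the pure-Python fallback path when janitor_util is not built:

    from lm_eval.decontamination import Janitor

    # Hypothetical example with deliberately tiny parameters so short strings match.
    janitor = Janitor(ngram_n=2, window_to_remove=5, too_dirty_cutoff=10, minimum_slice_length=5)

    # Register a "test set" passage; it is stored as a set of normalized
    # (lowercased, punctuation-stripped) 2-grams.
    janitor.register_contaminant("the secret benchmark answer")

    # Clean a "training" document: each matching ngram plus window_to_remove characters
    # on either side is cut out; the surviving slices longer than minimum_slice_length
    # are returned, or [] if more than too_dirty_cutoff ngrams matched.
    chunks = janitor.clean("lots of training text containing the secret benchmark answer plus more text")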
lm_eval/evaluator.py

...
...
@@ -5,7 +5,7 @@ import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
-from scripts.clean_training_data.contamination import get_train_overlap
+import lm_eval.decontamination
import numpy as np
from lm_eval.utils import positional_deprecated
...
...
@@ -193,7 +193,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        print("Finding train/test overlap, please wait...")
-       overlaps = get_train_overlap(docs_for_decontamination, decontaminate_ngrams_path, decontaminate_ngrams_n_size, limit)
+       overlaps = lm_eval.decontamination.get_train_overlap(docs_for_decontamination, decontaminate_ngrams_path, decontaminate_ngrams_n_size, limit)

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)
...
...
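
For reference, overlaps here is the mapping produced by the return statement shown near the top of lm_eval/decontamination.py above; an illustrative shape (task names and ids are made up, and the exact collection type of the doc ids is not visible in this diff):

    # overlaps, illustratively:
    # {
    #     "lambada": {3, 17, 240},   # document ids flagged as overlapping training-set ngrams
    #     "hellaswag": set(),        # no overlap found
    # }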
scripts/clean_training_data/janitor.py

...
...
@@ -4,6 +4,7 @@ import timeit
import pickle
import traceback
from pprint import pprint
+from lm_eval.decontamination import word_ngrams

# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
...
...
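
As a small illustration of the word_ngrams helper that scripts/clean_training_data/janitor.py now imports from lm_eval.decontamination (behavior read off the definition above; the example string is hypothetical):

    from lm_eval.decontamination import word_ngrams

    # Space-joined n-grams over whitespace-split tokens:
    print(list(word_ngrams("the quick brown fox", 2)))
    # ['the quick', 'quick brown', 'brown fox']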