OpenDAS / Megatron-LM · Commits

Commit c44f7622, authored Mar 04, 2021 by Mostofa Patwary
Many more features added
Parent 6013e23c

Showing 1 changed file with 184 additions and 81 deletions:

tools/openwebtext/filter_ngrams.py  (+184, -81)
@@ -24,6 +24,7 @@ from functools import partial
 import json
 import multiprocessing
+import nltk
 import pickle
 import re
 import string
 import sys
@@ -61,11 +62,23 @@ def split_text(text, start_position, remove_char_each_side, seq):
     return text_first, text_second
 
 
 def check_and_clean_text(args, words, ngrams, text, start_position, \
-    text_buf_ngram_free, text_buf):
+    text_buf_ngram_free, text_buf, local_ngram):
 
     seq = " ".join(words)
     if seq in ngrams:
-        print(" [matched]: {}".format(seq), flush=True)
+        #print(" [matched]: {}".format(seq), flush=True)
+
+        if args.get_ngram_freq_only:
+            # increase freq of this seq and then only consider the later part
+            # of the text for further processing
+            if seq in local_ngram:
+                local_ngram[seq] += 1
+            else:
+                local_ngram[seq] = 1
+            #print(" [increased]: {} {}".format(seq, ngrams[seq]), flush=True)
+            if (start_position + len(seq) + 1) < len(text):
+                text_buf.append(text[start_position + len(seq) + 1:len(text)])
+            return False
 
         # split the text
         text_first, text_second = split_text(text, start_position, \
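Note: check_and_clean_text now takes a local_ngram dictionary and, when args.get_ngram_freq_only is set, only counts how often each task n-gram occurs in a document and keeps scanning the remainder, instead of splitting the text at the match. A minimal sketch of that counting-only behaviour, using a hypothetical scan_matches helper in place of the real sliding-window scan (names here are illustrative, not from the script):

# Minimal sketch of the counting-only mode added in this commit.
# `task_ngrams` stands in for the task-derived `ngrams` dictionary and
# `scan_matches` is a hypothetical stand-in for the window scan that
# check_and_clean_text performs over a document.

def scan_matches(text, task_ngrams, ngram_size):
    words = text.split()
    for i in range(len(words) - ngram_size + 1):
        seq = " ".join(words[i:i + ngram_size])
        if seq in task_ngrams:
            yield seq

def count_task_ngrams(text, task_ngrams, ngram_size, local_ngram):
    # get_ngram_freq_only=True path: only accumulate frequencies,
    # never drop or split the text.
    for seq in scan_matches(text, task_ngrams, ngram_size):
        local_ngram[seq] = local_ngram.get(seq, 0) + 1
    return local_ngram

local = count_task_ngrams("the quick brown fox jumps over the lazy dog",
                          {"quick brown fox"}, 3, {})
print(local)  # {'quick brown fox': 1}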
@@ -84,6 +97,7 @@ def check_and_clean_text(args, words, ngrams, text, start_position, \
     # ngram free
     return True
 
 
 def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
     # remove all the ngrams
@@ -95,18 +109,19 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
         text_buf = []
 
     text_buf_ngram_free = []
+    local_ngram = {}
     while len(text_buf) > 0:
         # get the first one from the buffer
         text = text_buf.pop(0)
         words, positions = get_words(text)
 
         ngram_free = True
         # find each max n-grams and check dictionary
-        for i in range(len(words) - args.ngram_size + 1):
+        for i in range(len(words) - args.max_ngram_size + 1):
             check_ngram_free = check_and_clean_text(args, words[i:\
-                i + args.ngram_size], ngrams, text, positions[i], \
-                text_buf_ngram_free, text_buf)
+                i + args.max_ngram_size], ngrams, text, positions[i], \
+                text_buf_ngram_free, text_buf, local_ngram)
 
             # the seq is ngram free? if yes, break
             if not check_ngram_free:
@@ -118,7 +133,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                 for ngram_len, _ in ngrams_freq_sorted:
                     check_ngram_free = check_and_clean_text(args, words[i:\
                         i + ngram_len], ngrams, text, positions[i], \
-                        text_buf_ngram_free, text_buf)
+                        text_buf_ngram_free, text_buf, local_ngram)
 
                     # same check as above
                     if not check_ngram_free:
@@ -130,16 +145,16 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                 break
 
         # for the last max n-gram, check all the lower ngrams in it
-        if ngram_free and len(words) - args.ngram_size > 0:
+        if ngram_free and len(words) - args.max_ngram_size > 0:
             # get the last words of the lax max ngram
-            last_seq_words = words[(len(words) - args.ngram_size):len(words)]
-            last_seq_start_position = len(words) - args.ngram_size
+            last_seq_words = words[(len(words) - args.max_ngram_size):len(words)]
+            last_seq_start_position = len(words) - args.max_ngram_size
 
             # check all n-grams lower than the max
             for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):
 
                 # ignore the max ngram as has been considered already
-                if ngram_len == args.ngram_size:
+                if ngram_len == args.max_ngram_size:
                     continue
 
                 # find each ngram of ngram_len in max n-grams and check
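Note: the block above re-checks the trailing max_ngram_size-word window against every smaller n-gram length in ngrams_freq_sorted, so shorter task n-grams near the end of a document are not missed. A standalone sketch of that idea, assuming ngrams_freq_sorted holds (length, count) pairs and using a hypothetical contains callback in place of the dictionary lookup:

# Sketch of the trailing-window check, assuming ngrams_freq_sorted is a
# list of (ngram_len, count) pairs and `contains` reports whether a word
# slice appears in the task n-gram dictionary.

def check_trailing_window(words, max_ngram_size, ngrams_freq_sorted, contains):
    if len(words) - max_ngram_size <= 0:
        return True                      # whole text already covered
    last_seq_words = words[len(words) - max_ngram_size:]
    for ngram_len, _ in ngrams_freq_sorted:
        if ngram_len == max_ngram_size:  # already handled by the main scan
            continue
        for i in range(len(last_seq_words) - ngram_len + 1):
            if contains(last_seq_words[i:i + ngram_len]):
                return False             # a shorter task n-gram was found
    return True

task = {"lazy dog"}
words = "the quick brown fox jumps over the lazy dog".split()
ok = check_trailing_window(words, 4, [(2, 1), (4, 1)],
                           lambda w: " ".join(w) in task)
print(ok)  # False: "lazy dog" sits inside the last 4-word window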
@@ -147,7 +162,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                     check_ngram_free = check_and_clean_text(args, \
                         last_seq_words[i:i + ngram_len], ngrams, text, \
                         positions[last_seq_start_position + i], \
-                        text_buf_ngram_free, text_buf)
+                        text_buf_ngram_free, text_buf, local_ngram)
 
                     if not check_ngram_free:
                         ngram_free = False
@@ -157,34 +172,35 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                         break
 
         # texts are ngram free
-        if ngram_free:
+        if ngram_free and not args.get_ngram_freq_only:
            text_buf_ngram_free.append(text)
 
     # check if the text has only been trimmed
     trimmed = 0
-    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \
-        len(myjson[key]):
+    if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
+        len(text_buf_ngram_free[0]) < len(myjson[key]):
         trimmed = 1
 
-    return text_buf_ngram_free, trimmed
+    return text_buf_ngram_free, trimmed, local_ngram
 
 
 # insert word sequence into dictionary
 def insert_dict(words, ngrams, pos):
     seq = " ".join(words)
     if seq not in ngrams:
-        ngrams[seq] = pos
+        ngrams[seq] = 0
+        #ngrams[seq] = pos
 
 
 # insert each ngram from text into the ngrams dictionary
 def compute_ngrams_insert_dict(args, text, ngrams):
     words, positions = get_words(text)
-    if len(words) == 0:
+    if len(words) < args.min_ngram_size:
         return
 
-    if len(words) < args.ngram_size:
+    if len(words) < args.max_ngram_size:
         insert_dict(words, ngrams, positions[0])
 
-    for i in range(len(words) - args.ngram_size + 1):
-        insert_dict(words[i:i + args.ngram_size], ngrams, positions[i])
+    for i in range(len(words) - args.max_ngram_size + 1):
+        insert_dict(words[i:i + args.max_ngram_size], ngrams, positions[i])
 
 
 # Build ngrams for the lambada dataset
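Note: with the new flags, dictionary building skips task texts shorter than min_ngram_size words, stores shorter-than-max texts as a single key, and otherwise inserts every max_ngram_size-word window; the stored value is now a frequency seed of 0 rather than the position. A simplified sketch of that step (plain str.split in place of get_words, which in the script also returns character positions):

# Sketch of the dictionary-building step, mirroring insert_dict /
# compute_ngrams_insert_dict with simplified word splitting.

def build_ngram_dict(texts, min_ngram_size=8, max_ngram_size=13):
    ngrams = {}
    for text in texts:
        words = text.split()
        if len(words) < min_ngram_size:
            continue                       # too short to be a useful key
        if len(words) < max_ngram_size:
            ngrams.setdefault(" ".join(words), 0)
        for i in range(len(words) - max_ngram_size + 1):
            ngrams.setdefault(" ".join(words[i:i + max_ngram_size]), 0)
    return ngrams

d = build_ngram_dict(["one two three four five six seven eight nine ten"],
                     min_ngram_size=8, max_ngram_size=13)
print(len(d))  # 1: the 10-word text is stored as a single key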
@@ -203,6 +219,7 @@ def process_task_lambda(args, task_file, ngrams):
 
 # Build ngrams for the dataset of the given task
 def process_task(args, task_name, ngrams):
 
     print(' reading from {} and computing ngrams'.format('import datasets'))
+    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
 
     # using validation/test data from datasets
@@ -253,39 +270,7 @@ def process_task(args, task_name, ngrams):
     print(" After task {} entities in ngrams {}, added {}".format(task_name, \
         len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)
 
-if __name__ == '__main__':
-
-    # we use 13-grams, any text less than 200 characters got removed
-    # any text splitted more than 10 got removed as well
-
-    print('parsing the arguments ...')
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
-                        help='Tasks to use for deduplication: currently '
-                        ' suuport [lambada, squad, natural_questions,'
-                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
-    parser.add_argument('--lambada-path', type=str, default=None,
-                        help='Only Lambada task needs the path')
-    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
-                        help='Dataset to deduplicate with the key to use'
-                        ' e.g. cc.json text')
-    parser.add_argument('--output', type=str, default=None,
-                        help='Output file name to save dedup dataset')
-
-    # Default dedup values
-    parser.add_argument('--ngram-size', type=int, default=13,
-                        help='Maximum size of ngram to use.')
-    parser.add_argument('--filter-text-char-len', type=int, default=200,
-                        help='Remove any text below this length.')
-    parser.add_argument('--splits-count', type=int, default=10,
-                        help='Remove any documents more than this many splits')
-    parser.add_argument('--remove-char-each-side', type=int, default=200,
-                        help='Maximum size of ngram to use.')
-
-    args = parser.parse_args()
-
-    # Build ngrams
-    ngrams = {}
+def compute_tasks_ngrams(args, ngrams):
+    start_time = time.time()
     for _, task_name in enumerate(args.tasks):
         print('Task: {}'.format(task_name), flush=True)
@@ -294,10 +279,10 @@ if __name__ == '__main__':
             process_task_lambda(args, args.lambada_path, ngrams)
         else:
             process_task(args, task_name, ngrams)
-    print(" Taken time to compute ngrams {:.2f}".format(time.time() - \
-        start_time), flush=True)
+    print(" Taken time {:.2f}".format(time.time() - start_time), flush=True)
 
 # get the range of the size of the ngrams
+def compute_ngram_freq_sorted(args, ngrams):
     ngrams_freq = {}
     for ngram_key in ngrams.keys():
         length = len(ngram_key.split())
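Note: only part of compute_ngram_freq_sorted is visible in this diff; from the surrounding code it appears to bucket the collected n-grams by word length and return (length, count) pairs sorted by length, which is how ngrams_freq_sorted is consumed elsewhere. A hedged sketch of that assumed behaviour:

# Hedged sketch of what compute_ngram_freq_sorted appears to do; the sort
# order and return shape are assumptions based on how the result is used.

def ngram_freq_sorted(ngrams):
    ngrams_freq = {}
    for ngram_key in ngrams.keys():
        length = len(ngram_key.split())
        ngrams_freq[length] = ngrams_freq.get(length, 0) + 1
    return sorted(ngrams_freq.items())    # [(min_len, n), ..., (max_len, n)]

print(ngram_freq_sorted({"a b c": 0, "d e f": 0, "g h": 0}))
# [(2, 1), (3, 2)]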
@@ -309,33 +294,74 @@ if __name__ == '__main__':
     print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
         len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
         ngrams_freq_sorted) - 1][0]), flush=True)
+    return ngrams_freq_sorted
 
-    id_prefix = '-'.join(args.tasks[::2])
-    print('Reading file {} and deduping n-grams'.format(args.dedup_dataset))
+def get_ngrams_above_threshold(args, ngrams, ngrams_above_threshold, \
+    dedup_file, dedup_key, ngrams_freq_sorted):
+
+    start_time = time.time()
+    # get the ngrams frequency
+    args.get_ngram_freq_only = True
 
     # Open the large file to process in parallel
     num_workers = 40
     pool = multiprocessing.Pool(num_workers)
     fin = open(dedup_file, 'r', encoding='utf-8')
+    free_ngram_abt_partial = partial(free_ngram, args=args, key=dedup_key, \
+        ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
+    free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)
+
+    counter = 0
+    for _, _, local_ngram in free_ngrams_abt:
+        counter += 1
+        if counter % 1000 == 0:
+            print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
+                format(counter, time.time() - start_time), flush=True)
+        for local_key in local_ngram:
+            if local_key in ngrams:
+                ngrams[local_key] += 1
+        local_ngram = {}
+
+    print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
+        start_time), flush=True)
+    pool.close()
+    pool.join()
+
+    start_time = time.time()
+    counter_threshold = 0
+    # Get ngram above theadhold
+    for local_key, local_val in ngrams.items():
+        if ngrams[local_key] > args.key_threshold:
+            print(" [threshold] {} {}".format(local_key, local_val), flush=True)
+            counter_threshold += 1
+            ngrams_above_threshold[local_key] = 1
+
+    print(' Ngrams above threshold {}'.format(counter_threshold), flush=True)
+    fin.close()
 
-    if args.output is not None:
-        out_f = open(args.output, 'wb')
+def clean_ngrams_above_threshold(args, ngrams_above_threshold, dedup_file, \
+    dedup_key):
+
+    splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0
+    start_time = time.time()
+    # Now actually filter the dataset
+    args.get_ngram_freq_only = False
+    id_prefix = '-'.join(args.tasks[::2])
 
-    assert len(args.dedup_dataset) == 2
-    dedup_file = args.dedup_dataset[0]
-    dedup_key = args.dedup_dataset[1]
+    # get the range of the size of the ngrams
+    ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_above_threshold)
 
     # Setup multi-processing.
     # Open the large file to process in parallel
     num_workers = 40
     fin = open(dedup_file, 'r', encoding='utf-8')
     pool = multiprocessing.Pool(num_workers)
-    free_ngram_x = partial(free_ngram, args=args, key=dedup_key, ngrams=ngrams, \
-        ngrams_freq_sorted=ngrams_freq_sorted)
-    free_ngrams = pool.imap(free_ngram_x, fin, 25)
-    for text_buf_ngram_free, trimmed in free_ngrams:
+    free_ngram_clean_partial = partial(free_ngram, args=args, key=dedup_key, \
+        ngrams=ngrams_above_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
+    free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)
+
+    out_f = open(args.output, 'wb')
+
+    counter = splitted = ignored = split_mt_thld = trimmed_count = 0
+    for text_buf_ngram_free, trimmed, _ in free_ngrams_clean:
         counter += 1
         try:
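Note: the refactor above splits the work into two passes over the large JSON-lines file, both driven by multiprocessing.Pool.imap with free_ngram bound through functools.partial: get_ngrams_above_threshold only counts matches per document and keeps n-grams seen in more than --key-threshold documents, while clean_ngrams_above_threshold reruns the worker against that reduced dictionary to actually trim or drop documents. A minimal, self-contained sketch of the same two-pass pool.imap pattern with a toy worker (the real chunk sizes of 500 and 25 and num_workers=40 are tuning choices; names here are illustrative):

# Two-pass pool.imap sketch. Worker functions must be module-level
# (picklable) for multiprocessing.

import json
import multiprocessing
from functools import partial

def toy_worker(line, key, task_ngrams):
    # hypothetical stand-in for free_ngram: report which task n-grams
    # occur in the document's text field
    doc = json.loads(line)
    hits = {ng: 1 for ng in task_ngrams if ng in doc.get(key, "")}
    return doc, hits

def two_pass(path, key, task_ngrams, threshold):
    with multiprocessing.Pool(4) as pool:
        # pass 1: count how many documents each task n-gram appears in
        bound = partial(toy_worker, key=key, task_ngrams=task_ngrams)
        freq = {}
        with open(path, encoding='utf-8') as fin:
            for _, hits in pool.imap(bound, fin, chunksize=16):
                for ng in hits:
                    freq[ng] = freq.get(ng, 0) + 1
        frequent = {ng for ng, count in freq.items() if count > threshold}

        # pass 2: re-read the file and act only on the frequent n-grams
        bound = partial(toy_worker, key=key, task_ngrams=frequent)
        kept = []
        with open(path, encoding='utf-8') as fin:
            for doc, hits in pool.imap(bound, fin, chunksize=16):
                if not hits:
                    kept.append(doc)   # toy policy: keep untouched documents
        return kept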
@@ -361,18 +387,95 @@ if __name__ == '__main__':
                 out_f.write('\n'.encode('utf-8'))
 
             if counter % 1000 == 0:
-                print(' [search]> processed {} documents in {:.2f} seconds ...'.
+                print(' [final]> processed {} documents in {:.2f} seconds ...'.
                     format(counter, time.time() - start_time), flush=True)
 
         except Exception as e:
             print('Error:', e)
 
-    if args.output is not None:
-        out_f.close()
+    print(' [final]> processed {} documents in {:.2f} seconds ...'.
+        format(counter, time.time() - start_time), flush=True)
 
-    print(' Total docs {} splitted {} ignored {} splits > theshold {} trimmed' \
-        ' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count) \
-        , flush=True)
+    pool.close()
+    pool.join()
+
+    out_f.close()
+    fin.close()
+    print("Deduped file written to: {}".format(args.output), flush=True)
+    print("Total docs {} splitted {} ignored {} docs with many splits {}" \
+        " trimmed {}".format(counter, splitted, ignored, split_mt_thld, \
+        trimmed_count), flush=True)
+
+
+if __name__ == '__main__':
+
+    # we use 13-grams, any text less than 200 characters got removed
+    # any text splitted more than 10 got removed as well
+
+    print('parsing the arguments ...')
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
+                        help='Tasks to use for deduplication: currently '
+                        ' suuport [lambada, squad, natural_questions,'
+                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
+    parser.add_argument('--lambada-path', type=str, default=None,
+                        help='Only Lambada task needs the path')
+    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
+                        help='Dataset to deduplicate with the key to use'
+                        ' e.g. cc.json text')
+    parser.add_argument('--output', type=str, default=None,
+                        help='Output file name to save dedup dataset')
+
+    # Default dedup values
+    parser.add_argument('--max-ngram-size', type=int, default=13,
+                        help='Maximum size of ngram to use.')
+    parser.add_argument('--min-ngram-size', type=int, default=8,
+                        help='Minimum size of ngram to use.')
+    parser.add_argument('--filter-text-char-len', type=int, default=200,
+                        help='Remove any text below this length.')
+    parser.add_argument('--key-threshold', type=int, default=10,
+                        help='Number of keys to consider as threshold')
+    parser.add_argument('--save-dictionary', type=str, default=None,
+                        help='Save the dictionary')
+    parser.add_argument('--load-dictionary', type=str, default=None,
+                        help='Load the dictionary')
+    parser.add_argument('--splits-count', type=int, default=10,
+                        help='Remove any documents more than this many splits')
+    parser.add_argument('--remove-char-each-side', type=int, default=200,
+                        help='Maximum size of ngram to use.')
+
+    args = parser.parse_args()
+
+    assert len(args.dedup_dataset) == 2
+    dedup_file = args.dedup_dataset[0]
+    dedup_key = args.dedup_dataset[1]
+
+    # Setup multi-processing
+    num_workers = 40
+
+    if args.load_dictionary is None:
+
+        # Build ngrams
+        ngrams = {}
+        compute_tasks_ngrams(args, ngrams)
+
+        # get the range of the size of the ngrams
+        ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)
+
+        # get ngram freq from large file in parallel
+        # get ngrams above threshold
+        ngrams_above_threshold = {}
+        get_ngrams_above_threshold(args, ngrams, ngrams_above_threshold, \
+            dedup_file, dedup_key, ngrams_freq_sorted)
+
+        # save the dictionary if needed
+        if args.save_dictionary is not None:
+            with open(args.save_dictionary, 'wb') as save_dict_handle:
+                pickle.dump(ngrams_above_threshold, save_dict_handle)
+    else:
+        with open(args.load_dictionary, 'rb') as load_dict_handle:
+            ngrams_above_threshold = pickle.load(load_dict_handle)
+
+    # filter the large file
+    if args.output is not None:
+        clean_ngrams_above_threshold(args, ngrams_above_threshold, \
+            dedup_file, dedup_key)
+
+    print('done :-)')
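Note: the new --save-dictionary / --load-dictionary options pickle the thresholded n-gram dictionary so the expensive counting pass can be skipped on later runs, e.g. when re-filtering the same "cc.json text" dataset mentioned in the --dedup-dataset help string. A minimal sketch of that round trip (the .pkl file name below is hypothetical):

# Pickle round trip for the thresholded dictionary; path is hypothetical.

import pickle

def save_ngram_dict(ngrams_above_threshold, path):
    with open(path, 'wb') as handle:
        pickle.dump(ngrams_above_threshold, handle)

def load_ngram_dict(path):
    with open(path, 'rb') as handle:
        return pickle.load(handle)

save_ngram_dict({"some task n-gram": 1}, 'ngrams_above_threshold.pkl')
print(load_ngram_dict('ngrams_above_threshold.pkl'))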