OpenDAS / Megatron-LM · Commits · 90c9e3f2

Commit 90c9e3f2, authored Feb 25, 2021 by Mostofa Patwary

Other tasks dedup added

Parent: ebc4a408

Showing 1 changed file with 231 additions and 72 deletions:
tools/openwebtext/filter_ngrams.py (+231, -72)
@@ -19,6 +19,7 @@ All split documents with less than 200 characters got filtered. Any document
 with more than 10 splits got filtered as well.
 """
+import argparse
 from functools import partial
 import json
 import multiprocessing
@@ -36,40 +37,20 @@ def get_words(text):
         positions.append(match.start())
     return words, positions
 
-def free_ngram(line, ngrams, ngram_size, filter_text_len,
-    splits_count, split_window_each_size):
-    # remove all the ngrams
-
-    try:
-        myjson = json.loads(line)
-        text_buf = [myjson['text']]
-    except Exception as e:
-        print("Error: {}".format(e), flush=True)
-        text_buf = []
-
-    text_buf_ngram_free = []
-    while len(text_buf) > 0:
-
-        # get the first one from the buffer
-        text = text_buf.pop(0)
-        words, positions = get_words(text)
-
-        not_ngram_free = True
-        punctuations = ".!?"
-        # find n-grams
-        for i in range(len(words) - ngram_size + 1):
-            seq = " ".join(words[i:i+ngram_size])
-            if seq in ngrams:
-                # splits the text
-                # first part of the text
-                pos = positions[i] - split_window_each_size
-                text_first = ""
-                while pos > 0 and not text[pos] in punctuations:
-                    pos -= 1
-                if pos > 0:
-                    text_first = text[0:pos+1]
-
-                pos = positions[i] + split_window_each_size
-                # last part of the text
-                text_second = ""
-                while pos < len(text) and not text[pos] in punctuations:
+def split_text(text, start_position, remove_char_each_side, seq):
+
+    # first part of the text
+    punctuations = ".!?"
+    pos = start_position - remove_char_each_side
+    text_first = ""
+    while pos > 0 and not text[pos] in punctuations:
+        pos -= 1
+    if pos > 0:
+        text_first = text[0:pos+1]
+
+    # add length of seq and remove_char_each_side
+    pos = start_position + len(seq) + remove_char_each_side
+
+    # last part of the text
+    text_second = ""
+    while pos < len(text) and not text[pos] in punctuations:
@@ -77,78 +58,252 @@ def free_ngram(line, ngrams, ngram_size, filter_text_len,
-                if pos + 1 < len(text):
-                    text_second = text[pos+1:len(text)]
-
-                # first part of ngrams free
-                if len(text_first) > filter_text_len:
-                    text_buf_ngram_free.append(text_first)
-
-                # add second part for further processing
-                if len(text_second) > filter_text_len:
-                    text_buf.append(text_second)
-
-                not_ngram_free = False
-                break
-
-        # text are ngram free
-        if not_ngram_free:
-            text_buf_ngram_free.append(text)
-
-    return text_buf_ngram_free
-
-
-if __name__ == '__main__':
-
-    print('finding possible duplicate content ...')
-    main_file = sys.argv[1] # lambada file
-    dedup_file = sys.argv[2] # Book corpus
-    output_file = sys.argv[3] #Filtered book corpus
-    ngrams = {}
-    id_prefix = "lambada"
-
-    # we use 13-grams, any text less than 200 characters got removed
-    # any text splitted more than 10 got removed as well
-    ngram_size = 13
-    filter_text_len = 200
-    splits_count = 10
-    split_window_each_size = 200
-
-    print('Reading file {} and computing ngrams'.format(main_file))
-    with open(main_file, 'r') as f:
-        for line in f:
-            try:
-                myjson = json.loads(line)
-                words, positions = get_words(myjson['text'])
-                for i in range(len(words) - ngram_size + 1):
-                    seq = " ".join(words[i:i+ngram_size])
-                    if seq not in ngrams:
-                        ngrams[seq] = positions[i]
-            except Exception as e:
-                print('Error:', e)
-    print("ngrams size {}".format(len(ngrams)))
-
-    print('Reading file {} and deduping n-grams'.format(dedup_file))
-    counter = 0
-    start_time = time.time()
-    out_f = open(output_file, 'wb')
-    splitted, ignored, split_mt_thld = 0, 0, 0
-
-    # Setup multi-processing.
-    num_workers = 40
-    fin = open(dedup_file, 'r', encoding='utf-8')
-    pool = multiprocessing.Pool(num_workers)
-    free_ngram_x = partial(free_ngram, ngrams=ngrams, ngram_size=ngram_size, \
-        filter_text_len=filter_text_len, splits_count=splits_count, \
-        split_window_each_size=split_window_each_size)
-    free_ngrams = pool.imap(free_ngram_x, fin, 25)
-
-    for text_buf_ngram_free in free_ngrams:
-        counter += 1
-        try:
-
-            if len(text_buf_ngram_free) > 1:
-                splitted += (len(text_buf_ngram_free) - 1)
-            if len(text_buf_ngram_free) == 0:
-                ignored += 1
-            # more than 10 splits ignored
-            if len(text_buf_ngram_free) > splits_count:
-                text_buf_ngram_free = []
-                split_mt_thld += 1
+    if pos + 1 < len(text):
+        text_second = text[pos+1:len(text)]
+
+    return text_first, text_second
+
+
+def check_and_clean_text(args, words, ngrams, text, start_position, \
+    text_buf_ngram_free, text_buf):
+
+    seq = " ".join(words)
+    if seq in ngrams:
+        print(" [matched]: {}".format(seq), flush=True)
+
+        # split the text
+        text_first, text_second = split_text(text, start_position, \
+            args.remove_char_each_side, seq)
+
+        # first part of ngrams free
+        if len(text_first) > args.filter_text_char_len:
+            text_buf_ngram_free.append(text_first)
+
+        # add second part for further processing
+        if len(text_second) > args.filter_text_char_len:
+            text_buf.append(text_second)
+
+        return False # not ngram free
+
+    # ngram free
+    return True
+
+
+def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
+    # remove all the ngrams
+
+    try:
+        myjson = json.loads(line)
+        text_buf = [myjson[key]]
+    except Exception as e:
+        print("Error: {}".format(e), flush=True)
+        text_buf = []
+
+    text_buf_ngram_free = []
+    while len(text_buf) > 0:
+
+        # get the first one from the buffer
+        text = text_buf.pop(0)
+        words, positions = get_words(text)
+
+        ngram_free = True
+        # find each max n-grams and check dictionary
+        for i in range(len(words) - args.ngram_size + 1):
+            check_ngram_free = check_and_clean_text(args, words[i:\
+                i+args.ngram_size], ngrams, text, positions[i], \
+                text_buf_ngram_free, text_buf)
+
+            # the seq is ngram free? if yes, break
+            if not check_ngram_free:
+                ngram_free = False
+                break
+
+            # if max ngrams doesn't match, check if any other lower n-grams
+            # within max ngram macthes
+            for ngram_len, _ in ngrams_freq_sorted:
+                check_ngram_free = check_and_clean_text(args, words[i:\
+                    i+ngram_len], ngrams, text, positions[i], \
+                    text_buf_ngram_free, text_buf)
+
+                # same check as above
+                if not check_ngram_free:
+                    ngram_free = False
+                    break
+
+            # check break from lower than max ngram loop above
+            if not ngram_free:
+                break
+
+        # for the last max n-gram, check all the lower ngrams in it
+        if ngram_free and len(words) - args.ngram_size > 0:
+
+            # get the last words of the lax max ngram
+            last_seq_words = words[(len(words) - args.ngram_size):len(words)]
+            last_seq_start_position = len(words) - args.ngram_size
+
+            # check all n-grams lower than the max
+            for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):
+
+                # ignore the max ngram as has been considered already
+                if ngram_len == args.ngram_size:
+                    continue
+
+                # find each ngram of ngram_len in max n-grams and check
+                for i in range(len(last_seq_words) - ngram_len + 1):
+                    check_ngram_free = check_and_clean_text(args, \
+                        last_seq_words[i:i+ngram_len], ngrams, text, \
+                        positions[last_seq_start_position+i], \
+                        text_buf_ngram_free, text_buf)
+
+                    if not check_ngram_free:
+                        ngram_free = False
+                        break
+
+                if not ngram_free:
+                    break
+
+        # texts are ngram free
+        if ngram_free:
+            text_buf_ngram_free.append(text)
+
+    # check if the text has only been trimmed
+    trimmed = 0
+    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) == \
+        len(myjson[key]):
+        trimmed = 1
+
+    return text_buf_ngram_free, trimmed
+
+
+# insert word sequence into dictionary
+def insert_dict(words, ngrams, pos):
+    seq = " ".join(words)
+    if seq not in ngrams:
+        ngrams[seq] = pos
+
+
+# insert each ngram from text into the ngrams dictionary
+def compute_ngrams_insert_dict(args, text, ngrams):
+    words, positions = get_words(text)
+    if len(words) == 0:
+        return
+
+    if len(words) < args.ngram_size:
+        insert_dict(words, ngrams, positions[0])
+
+    for i in range(len(words) - args.ngram_size + 1):
+        insert_dict(words[i:i+args.ngram_size], ngrams, positions[i])
+
+
+# Build ngrams for the lambada dataset
+def process_task_lambda(args, task_file, ngrams):
+    print(' reading from {} and computing ngrams'.format(task_file))
+    with open(task_file, 'r') as f:
+        for line in f:
+            try:
+                myjson = json.loads(line)
+                text = myjson['text']
+                compute_ngrams_insert_dict(args, text, ngrams)
+            except Exception as e:
+                print('Error:', e)
+    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
+
+
+# Build ngrams for the squad v2 dataset
+def process_task_squad(args, ngrams):
+    print(' reading from {} and computing ngrams'.format('import datasets'))
+
+    # using squad data from datasets
+    from datasets import load_dataset
+    squad_v2 = load_dataset('squad_v2', split='validation')
+    for line in squad_v2:
+        try:
+            text = line['question']
+            compute_ngrams_insert_dict(args, text, ngrams)
+        except Exception as e:
+            print('Error:', e)
+    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
+
+
+if __name__ == '__main__':
+
+    # we use 13-grams, any text less than 200 characters got removed
+    # any text splitted more than 10 got removed as well
+
+    print('parsing the arguments ...')
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--tasks', nargs='*', required=True, default=None, \
+                        help='Tasks to use for deduplication: currently '
+                        ' suuport [lambada, squad]')
+    parser.add_argument('--lambada-path', type=str, default=None,
+                        help='Only Lambada task needs the path')
+    parser.add_argument('--dedup-dataset', nargs='*', default=None,
+                        help='Dataset to deduplicate with the key to use'
+                        ' e.g. cc.json text')
+    parser.add_argument('--output', type=str, default=None,
+                        help='Output file name to save dedup dataset')
+
+    # Default dedup values
+    parser.add_argument('--ngram-size', type=int, default=13,
+                        help='Maximum size of ngram to use.')
+    parser.add_argument('--filter-text-char-len', type=int, default=200,
+                        help='Remove any text below this length.')
+    parser.add_argument('--splits-count', type=int, default=10,
+                        help='Remove any documents more than this many splits')
+    parser.add_argument('--remove-char-each-side', type=int, default=200,
+                        help='Maximum size of ngram to use.')
+
+    args = parser.parse_args()
+
+    # Build ngrams
+    ngrams = {}
+    for _, task_name in enumerate(args.tasks):
+        print('Task: {}'.format(task_name), flush=True)
+        if task_name == 'lambada':
+            assert args.lambada_path is not None
+            process_task_lambda(args, args.lambada_path, ngrams)
+        if task_name == 'squad':
+            process_task_squad(args, ngrams)
+
+    # get the range of the size of the ngrams
+    ngrams_freq = {}
+    for ngram_key in ngrams.keys():
+        length = len(ngram_key.split())
+        ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
+            ngrams_freq else 1
+
+    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[1])
+    print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
+    print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
+        len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
+        ngrams_freq_sorted) - 1][0]), flush=True)
+
+    id_prefix = '-'.join(args.tasks[::2])
+
+    print('Reading file {} and deduping n-grams'.format(args.dedup_dataset))
+    counter = 0
+    start_time = time.time()
+    out_f = open(args.output, 'wb')
+    splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0
+
+    assert len(args.dedup_dataset) == 2
+    dedup_file = args.dedup_dataset[0]
+    dedup_key = args.dedup_dataset[1]
+
+    # Setup multi-processing.
+    num_workers = 1 # 40
+    fin = open(dedup_file, 'r', encoding='utf-8')
+    pool = multiprocessing.Pool(num_workers)
+    free_ngram_x = partial(free_ngram, args=args, key=dedup_key, ngrams=ngrams, \
+        ngrams_freq_sorted=ngrams_freq_sorted)
+    free_ngrams = pool.imap(free_ngram_x, fin, 25)
+
+    for text_buf_ngram_free, trimmed in free_ngrams:
+        counter += 1
+        try:
+
+            trimmed_count += trimmed
+
+            if len(text_buf_ngram_free) > 1:
+                splitted += (len(text_buf_ngram_free) - 1)
+            if len(text_buf_ngram_free) == 0:
+                ignored += 1
+            # more than 10 splits ignored
+            if len(text_buf_ngram_free) > args.splits_count:
+                text_buf_ngram_free = []
+                split_mt_thld += 1
@@ -167,7 +322,11 @@ if __name__ == '__main__':
         except Exception as e:
             print('Error:', e)
 
-    print("Deduped file written to: {}".format(output_file), flush=True)
-    print("Total docs {} splitted {} ignored {} docs with many splits {}".\
-        format(counter, splitted, ignored, split_mt_thld), flush=True)
+    out_f.close()
+    fin.close()
+
+    print("Deduped file written to: {}".format(args.output), flush=True)
+    print("Total docs {} splitted {} ignored {} docs with many splits {}"\
+        " trimmed {}".format(counter, splitted, ignored, split_mt_thld, \
+        trimmed_count), flush=True)
 
     print('done :-)')
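
For reference, a hypothetical invocation of the updated script based on the arguments this commit adds; the dataset file names below are placeholders, except "cc.json text" which follows the example given in the --dedup-dataset help string:

    python tools/openwebtext/filter_ngrams.py \
        --tasks lambada squad \
        --lambada-path <path/to/lambada.jsonl> \
        --dedup-dataset cc.json text \
        --output <path/to/dedup_output.json>

The defaults added for --ngram-size (13), --filter-text-char-len (200), --splits-count (10), and --remove-char-each-side (200) mirror the constants that were hard-coded in the previous, lambada-only version of the script (ngram_size, filter_text_len, splits_count, split_window_each_size).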