OpenDAS / Megatron-LM · Commits
"server/vscode:/vscode.git/clone" did not exist on "8511669cb29115bdf0bc2da5328e69d041030996"
Commit 6013e23c authored Mar 02, 2021 by Mostofa Patwary

Dedup for other tasks added

parent b08b5edc
Showing 1 changed file with 70 additions and 24 deletions

tools/openwebtext/filter_ngrams.py  (+70, -24)
...
...
@@ -162,7 +162,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
     # check if the text has only been trimmed
     trimmed = 0
-    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) == \
+    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \
         len(myjson[key]):
         trimmed = 1
...
...
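The one-character fix above (== becomes <) is easiest to see on a concrete case; a minimal sketch, assuming free_ngram() left a single surviving piece that is shorter than the original document text (the strings are illustrative):

    trimmed = 0
    text_buf_ngram_free = ['the quick brown fox']    # one surviving piece, 19 chars
    original_text = 'the quick brown fox jumps'      # original document text, 25 chars
    # with '==' this document was never flagged; with '<' it is
    if len(text_buf_ngram_free) == 1 and \
            len(text_buf_ngram_free[0]) < len(original_text):
        trimmed = 1                                  # now set, since 19 < 25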
@@ -201,21 +201,57 @@ def process_task_lambda(args, task_file, ngrams):
     print(" Entities in ngrams {}".format(len(ngrams)), flush=True)

-# Build ngrams for the squad v2 dataset
-def process_task_squad(args, ngrams):
+# Build ngrams for the dataset of the given task
+def process_task(args, task_name, ngrams):

-    print(' reading from {} and computing ngrams'.format('import datasets'))
-    # using squad data from datasets
+    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
+    # using validation/test data from datasets
     from datasets import load_dataset
-    squad_v2 = load_dataset('squad_v2', split='validation')
-    for line in squad_v2:
+
+    entities_in_ngrams = len(ngrams)
+
+    # load the dataset
+    if task_name == 'squad':
+        dataset = load_dataset('squad_v2', split='validation')
+    elif task_name == 'natural_questions':
+        dataset = load_dataset('natural_questions', split='validation')
+    elif task_name == 'triviaqa':
+        dataset = load_dataset('trivia_qa', 'unfiltered', split='test')
+    elif task_name == 'webqa':
+        dataset = load_dataset('web_questions', split='test')
+    elif task_name == 'race':
+        dataset = load_dataset('race', 'all', split='test')
+    elif task_name == 'drop':
+        dataset = load_dataset('drop', split='validation')
+    elif task_name == 'coqa':
+        dataset = load_dataset('coqa', split='validation')
+    elif task_name == 'piqa':
+        dataset = load_dataset('piqa', split='test')
+    else:
+        print("Invalid task name: {}".format(task_name), flush=True)
+        return
+
+    # read the dataset and add to ngrams
+    for line in dataset:
         try:
-            text = line['question']
-            compute_ngrams_insert_dict(args, text, ngrams)
+            if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']:
+                text = line['question']
+                compute_ngrams_insert_dict(args, text, ngrams)
+            elif task_name == 'natural_questions':
+                text = line['question']['text']
+                compute_ngrams_insert_dict(args, text, ngrams)
+            elif task_name == 'coqa':
+                all_questions = line['questions']
+                for question in all_questions:
+                    compute_ngrams_insert_dict(args, question, ngrams)
+            elif task_name == 'piqa':
+                text = line['goal']
+                compute_ngrams_insert_dict(args, text, ngrams)
         except Exception as e:
             print('Error:', e)

-    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
+    print(" After task {} entities in ngrams {}, added {}".format(task_name, \
+        len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)

 if __name__ == '__main__':
...
...
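The per-task branches above read different record fields because the Hugging Face datasets ship with different schemas. A hedged illustration (requires the datasets package, network access, and in some cases large downloads; the fields printed are exactly the ones the new code relies on):

    from datasets import load_dataset

    # PIQA keeps its prompt under 'goal' (the 'piqa' branch above)
    piqa = load_dataset('piqa', split='test')
    print(piqa[0]['goal'])

    # CoQA stores a list of questions per passage (the 'coqa' branch above)
    coqa = load_dataset('coqa', split='validation')
    print(coqa[0]['questions'][:2])

    # Natural Questions nests the question text one level deeper
    nq = load_dataset('natural_questions', split='validation')
    print(nq[0]['question']['text'])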
@@ -227,7 +263,8 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--tasks', nargs='*', required=True, default=None, \
                         help='Tasks to use for deduplication: currently '
-                        ' suuport [lambada, squad]')
+                        ' suuport [lambada, squad, natural_questions,'
+                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
     parser.add_argument('--lambada-path', type=str, default=None,
                         help='Only Lambada task needs the path')
     parser.add_argument('--dedup-dataset', nargs='*', default=None,
...
...
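Given the expanded --tasks list, an invocation might look like the sketch below. The flag names --tasks, --lambada-path, and --dedup-dataset come from the argparse block above; the output flag is assumed from the args.output references later in the script, and the file name and key passed to --dedup-dataset (which the script later asserts to be exactly two values) are purely illustrative:

    python tools/openwebtext/filter_ngrams.py \
        --tasks squad triviaqa webqa race drop coqa piqa natural_questions \
        --dedup-dataset openwebtext.json text \
        --output openwebtext_ngram_filtered.json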
@@ -249,13 +286,16 @@ if __name__ == '__main__':
     # Build ngrams
     ngrams = {}
     start_time = time.time()
     for _, task_name in enumerate(args.tasks):
         print('Task: {}'.format(task_name), flush=True)
         if task_name == 'lambada':
             assert args.lambada_path is not None
             process_task_lambda(args, args.lambada_path, ngrams)
-        if task_name == 'squad':
-            process_task_squad(args, ngrams)
+        else:
+            process_task(args, task_name, ngrams)
     print(" Taken time {:.2f}".format(time.time() - start_time), flush=True)

     # get the range of the size of the ngrams
     ngrams_freq = {}
...
...
@@ -263,8 +303,8 @@ if __name__ == '__main__':
         length = len(ngram_key.split())
         ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
             ngrams_freq else 1

-    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[1])
+    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])
     print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
     print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format( \
         len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len( \
...
...
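Switching the sort key from item[1] to item[0] orders the (length, count) pairs by n-gram length rather than by frequency, which is what the ngrams_freq_sorted[0][0] lookup for min_ngram_size in the context above expects. A small illustration with made-up counts:

    ngrams_freq = {13: 40, 8: 120, 25: 5}   # {ngram length: count}, values invented

    sorted(ngrams_freq.items(), key=lambda item: item[1])
    # [(25, 5), (13, 40), (8, 120)]  -> first entry is the rarest length, not the shortest

    sorted(ngrams_freq.items(), key=lambda item: item[0])
    # [(8, 120), (13, 40), (25, 5)]  -> first entry yields min_ngram_size, last yields max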
@@ -276,7 +316,10 @@ if __name__ == '__main__':
     counter = 0
     start_time = time.time()
-    out_f = open(args.output, 'wb')
+    if args.output is not None:
+        out_f = open(args.output, 'wb')
     splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0

     assert len(args.dedup_dataset) == 2
...
...
@@ -299,7 +342,7 @@ if __name__ == '__main__':
             trimmed_count += trimmed

             if len(text_buf_ngram_free) > 1:
-                splitted += (len(text_buf_ngram_free) - 1)
+                splitted += 1
             if len(text_buf_ngram_free) == 0:
                 ignored += 1

             # more than 10 splits ignored
...
...
@@ -307,14 +350,15 @@ if __name__ == '__main__':
                 text_buf_ngram_free = []
                 split_mt_thld += 1

-            for i in range(len(text_buf_ngram_free)):
-                split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \
-                    + '-{:010d}'.format(int(i))
-                outjson = json.dumps({"text": text_buf_ngram_free[i],
-                    id_prefix + "_split_id": split_id_string},
-                    ensure_ascii=False)
-                out_f.write(outjson.encode('utf-8'))
-                out_f.write('\n'.encode('utf-8'))
+            if args.output is not None:
+                for i in range(len(text_buf_ngram_free)):
+                    split_id_string = id_prefix + '-{:010d}'.format(int( \
+                        counter)) + '-{:010d}'.format(int(i))
+                    outjson = json.dumps({"text": text_buf_ngram_free[i],
+                        id_prefix + "_split_id": split_id_string},
+                        ensure_ascii=False)
+                    out_f.write(outjson.encode('utf-8'))
+                    out_f.write('\n'.encode('utf-8'))

             if counter % 1000 == 0:
                 print(' [search]> processed {} documents in {:.2f} seconds ...'.
...
...
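For reference, a minimal sketch of one record the guarded write loop above emits, assuming an id_prefix of 'openwebtext' and a counter value of 7 (both illustrative; id_prefix and counter are defined earlier in the script):

    import json

    id_prefix = 'openwebtext'            # illustrative value
    counter, i = 7, 0                    # document counter and split index
    split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \
        + '-{:010d}'.format(int(i))
    outjson = json.dumps({"text": "ngram-free piece of the document",
                          id_prefix + "_split_id": split_id_string},
                         ensure_ascii=False)
    print(outjson)
    # {"text": "ngram-free piece of the document", "openwebtext_split_id": "openwebtext-0000000007-0000000000"}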
@@ -322,7 +366,9 @@ if __name__ == '__main__':
         except Exception as e:
             print('Error:', e)

-    out_f.close()
+    if args.output is not None:
+        out_f.close()
     fin.close()

     print("Deduped file written to: {}".format(args.output), flush=True)
...
...