OpenDAS / Megatron-LM

Commit 882683dc, authored Mar 19, 2021 by Mostofa Patwary
Parent commit: 47e7eb80

Added more features to train data deduplication
Showing 3 changed files with 142 additions and 34 deletions:
  tools/openwebtext/add_id.py                 +48  -0
  tools/openwebtext/find_duplicates.py        +89  -29
  tools/openwebtext/group_duplicate_url.py     +5  -5
tools/openwebtext/add_id.py (new file, mode 100644)
import argparse
import json
import os
import time


if __name__ == '__main__':

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-file', type=str, default=None,
                        help='Input json file where id needs to be added')
    parser.add_argument('--output-file', type=str, default=None,
                        help='Output file name with id')
    parser.add_argument('--id-prefix', type=str, default=None,
                        help='Id prefix')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Log interval')
    args = parser.parse_args()

    print('Adding ids to dataset ...')

    f_input = open(args.input_file, 'r', encoding='utf-8')
    f_output = open(args.output_file, 'wb')

    unique_ids = 1
    start_time = time.time()
    for row in f_input:
        each_row = json.loads(row)
        adlr_id_string = args.id_prefix + '-{:010d}'.format(int(unique_ids))
        each_row['adlr_id'] = adlr_id_string
        myjson = json.dumps(each_row, ensure_ascii=False)

        f_output.write(myjson.encode('utf-8'))
        f_output.write('\n'.encode('utf-8'))

        if unique_ids % args.log_interval == 0:
            print(' processed {:9d} documents in {:.2f} seconds ...'.format(
                unique_ids, time.time() - start_time), flush=True)

        unique_ids += 1

    # Close the file.
    f_input.close()
    f_output.close()

    print('done :-)', flush=True)
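For orientation, here is a hedged sketch of how this new script might be invoked and what it adds to each record; the file names and prefix below are placeholders, not part of the commit:

# Hypothetical invocation (paths and prefix are illustrative only):
#   python tools/openwebtext/add_id.py --input-file corpus.json \
#       --output-file corpus_with_id.json --id-prefix owt
#
# For an input line such as
#   {"url": "https://example.com/a", "text": "some document text"}
# the script emits the same record with a sequential, zero-padded id appended:
#   {"url": "https://example.com/a", "text": "some document text", "adlr_id": "owt-0000000001"}
# because each id is built as args.id_prefix + '-{:010d}'.format(unique_ids).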
tools/openwebtext/find_duplicates.py
@@ -14,9 +14,11 @@
 # limitations under the License.

 import argparse
+from functools import partial
 import itertools
 import json
 from lsh import cache, minhash
+import multiprocessing
 import numpy as np
 import time
 import pickle
@@ -31,11 +33,31 @@ def shingles(text, char_ngram=5):
 # This function is adapted from:
 # https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
-def jaccard(set_a, set_b):
+def jaccard(set_a, set_b, args):
     if len(set_a) < 1 or len(set_b) < 1:
         return 0.0

     intersection = set_a & set_b
     union = set_a | set_b

-    return len(intersection) / len(union)
+    if args.jaccard == 'min':
+        return len(intersection) / min(len(set_a), len(set_b))
+    elif args.jaccard == 'max':
+        return len(intersection) / max(len(set_a), len(set_b))
+    else:
+        return len(intersection) / len(union)
+
+
+def compute_fingerprint(line, key):
+    try:
+        myjson = json.loads(line)
+        url = myjson[key]
+        text = myjson['text']
+        fingerprint = hasher.fingerprint(text)
+    except Exception as e:
+        print('Error:', e)
+        return None, None, None, False
+
+    return url, text, fingerprint, True


 if __name__ == '__main__':
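The --jaccard option introduced by this commit (its argparse definition appears in a later hunk) selects the denominator of the similarity score: the size of the union (default), of the smaller set, or of the larger set. A toy illustration, with shingle sets invented for the example:

# Toy sets standing in for character-shingle sets; not from the commit.
set_a = {'abcde', 'bcdef', 'cdefg'}            # 3 shingles
set_b = {'abcde', 'bcdef', 'xyzuv', 'yzuvw'}   # 4 shingles
inter = set_a & set_b                          # 2 shared shingles

print(len(inter) / len(set_a | set_b))               # 'union': 2 / 5 = 0.4
print(len(inter) / min(len(set_a), len(set_b)))      # 'min':   2 / 3 ~ 0.67
print(len(inter) / max(len(set_a), len(set_b)))      # 'max':   2 / 4 = 0.5

Since len(inter) / min(...) is never smaller than the other two ratios, the 'min' setting flags at least as many candidate pairs at the fixed 0.5 cutoff used later in this file.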
@@ -55,17 +77,29 @@ if __name__ == '__main__':
     parser.add_argument('--output', type=str, default=None,
                         help='Output file name that consists of all ids'
                         ' with matching similarities')
+    parser.add_argument('--jaccard', type=str, default='union',
+                        choices=['union', 'min', 'max'],
+                        help='Jaccard similarity computation')
+    parser.add_argument('--heuristic-iter', type=int, default=1,
+                        help='Number of iterations to run the heuristics'
+                        ': use -1 for exact')
+    parser.add_argument('--num-bands', type=int, default=10,
+                        help='Number of bands to use in cache')
+    parser.add_argument('--num-seeds', type=int, default=100,
+                        help='Number of seeds to use for minhash. Note that'
+                        ' this value should be divisible by num-bands')
     args = parser.parse_args()

     print('finding possible duplicate content ...')

     # set seed and get an array of seeds of 100 integers
     np.random.seed(args.seed)
-    seeds = np.random.randint(0, 1e6, size=100)
+    seeds = np.random.randint(0, 1e6, size=args.num_seeds)

     # initialize minhash and lsh cache
     hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
-    lshcache = cache.Cache(bands=10, hasher=hasher)
+    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

     url_doc = {}
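As a rough, untested sketch of how these pieces fit together in isolation, the following uses only the lsh-library calls that already appear in this file (MinHasher, Cache, add_fingerprint, bins); the two documents and the seed are invented:

# Sketch only: toy documents, same lsh-library calls as in this file.
from lsh import cache, minhash
import numpy as np

np.random.seed(1234)
num_seeds, num_bands = 100, 10   # num-seeds should be divisible by num-bands
seeds = np.random.randint(0, 1e6, size=num_seeds)

hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
lshcache = cache.Cache(num_bands=num_bands, hasher=hasher)

for url, text in [('url-a', 'the quick brown fox jumps over the lazy dog'),
                  ('url-b', 'the quick brown fox jumped over the lazy dog')]:
    lshcache.add_fingerprint(hasher.fingerprint(text), url)

# Documents that landed in the same bucket of any band are candidate duplicates.
for b in lshcache.bins:
    for bucket_id in b:
        if len(b[bucket_id]) > 1:
            print(b[bucket_id])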
@@ -100,22 +134,28 @@ if __name__ == '__main__':
     for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
         print(' document processing {} with key {}'.format(input_file, key),
               flush=True)

+        # compute fingerprints in parallel
+        num_workers = 20
+        pool = multiprocessing.Pool(num_workers)
+        fin = open(input_file, 'r', encoding='utf-8')
+        compute_fingerprint_partial = partial(compute_fingerprint, key=key)
+        compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
+                                             fin, 500)
+
         # traverse all the texts and add fingerprints
-        with open(input_file, 'r') as f_input:
-            for line in f_input:
-                try:
-                    myjson = json.loads(line)
-                    url = myjson[key]
-                    text = myjson['text']
-                    counter += 1
-                    url_doc[url] = text
-                    lshcache.add_fingerprint(hasher.fingerprint(text), url)
-                except Exception as e:
-                    print('Error:', e)
-                if counter % 10000 == 0:
-                    print(' [read]> processed {} documents in {:.2f} '
-                          'seconds ...'.format(counter, time.time() - \
-                          start_time), flush=True)
+        for url, text, fingerprint, flag in compute_fingerprint_iter:
+            counter += 1
+            if flag:
+                url_doc[url] = text
+                lshcache.add_fingerprint(fingerprint, url)
+            if counter % 10000 == 0:
+                print(' [read]> processed {} documents in {:.2f} '
+                      'seconds ...'.format(counter, time.time() - \
+                      start_time), flush=True)
+
+        fin.close()
+        pool.close()
+        pool.join()

     # Save the fingerprints if needed
     if args.save_fingerprints is not None:
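The parallel read path boils down to functools.partial plus multiprocessing.Pool.imap, which streams lines from the open file handle to worker processes in chunks of 500 and yields results in input order. A self-contained toy version of the same pattern, with the parsing function and sample lines invented for illustration:

# Toy version of the partial + Pool.imap pattern used above; parse_line and
# the sample lines are invented, only the pattern matches the commit.
import json
import multiprocessing
from functools import partial


def parse_line(line, key):
    try:
        doc = json.loads(line)
        return doc[key], doc['text'], True
    except Exception:
        return None, None, False


if __name__ == '__main__':
    lines = ['{"url": "a", "text": "first doc"}',
             '{"url": "b", "text": "second doc"}',
             'not valid json']
    pool = multiprocessing.Pool(2)
    parse_line_partial = partial(parse_line, key='url')
    # chunksize=2 here plays the role of the 500 used in find_duplicates.py
    for url, text, ok in pool.imap(parse_line_partial, lines, 2):
        if ok:
            print(url, text)
    pool.close()
    pool.join()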
@@ -133,32 +173,52 @@ if __name__ == '__main__':
     f_out = open(args.output, 'wb')
     for b in lshcache.bins:
         for bucket_id in b:
-            if len(b[bucket_id]) > 1:
-                items = list(b[bucket_id])
-                main_url = items[0]
-                main_dhingles = shingles(url_doc[main_url])
+            if len(b[bucket_id]) <= 1:
+                continue
+
+            bucket_urls = b[bucket_id].copy()
+            iteration = 0
+            while len(bucket_urls) > 1:
+
+                if args.heuristic_iter != -1 and \
+                    iteration == args.heuristic_iter:
+                    break
+
+                items = list(bucket_urls)
                 remove_urls = []
-                for i in range(1, len(items)):
+                main_url = items[np.random.randint(0, len(items))]
+                main_dhingles = shingles(url_doc[main_url])
+                for i in range(0, len(items)):
                     counter += 1
                     other_url = items[i]
+                    if other_url == main_url:
+                        continue
                     other_shingles = shingles(url_doc[other_url])
                     try:
-                        jaccard_sim = jaccard(main_dhingles, other_shingles)
+                        jaccard_sim = jaccard(main_dhingles, other_shingles,
+                                              args)
                     except Exception as e:
                         print('Error:', e)
                         jaccard_sim = 0.0
                     if jaccard_sim > 0.5:
                         remove_urls.append({other_url: jaccard_sim})
                         deduped += 1
+                        bucket_urls.remove(other_url)
                     if counter % 10000 == 0:
                         print(' [write]> processed {} documents in {:.2f} '
                               'seconds and deduped {} documents ...'.format(
                                   counter, time.time() - start_time,
                                   deduped), flush=True)
+                bucket_urls.remove(main_url)
                 if len(remove_urls) > 0:
                     myjson = json.dumps({main_url: remove_urls},
                                         ensure_ascii=False)
                     f_out.write(myjson.encode('utf-8'))
                     f_out.write('\n'.encode('utf-8'))
+                iteration += 1

     f_out.close()

     print('done :-)')
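Each line written to --output maps one randomly chosen 'main' url in a bucket to the members removed as its duplicates, together with their similarity scores. A small, hypothetical consumer of that file (the file name and the example record are illustrative):

# Hypothetical reader of the find_duplicates.py output file.
import json

with open('duplicates.json', 'r', encoding='utf-8') as f:
    for line in f:
        # e.g. {"http://a": [{"http://b": 0.83}, {"http://c": 0.61}]}
        record = json.loads(line)
        for main_url, removals in record.items():
            for entry in removals:
                for other_url, sim in entry.items():
                    print(main_url, other_url, sim)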
tools/openwebtext/group_duplicate_url.py
@@ -18,10 +18,6 @@ import time
 import sys


-def is_similar(jaccard_similarity):
-    return (jaccard_similarity >= 0.7)
-
-
 if __name__ == '__main__':
@@ -29,6 +25,10 @@ if __name__ == '__main__':
     input = sys.argv[1]
     output = sys.argv[2]
+    if len(sys.argv) > 3:
+        jaccard_similarity_threshold = float(sys.argv[3])
+    else:
+        jaccard_similarity_threshold = 0.7

     url_to_index = {}
     index_to_urls = []
@@ -43,7 +43,7 @@ if __name__ == '__main__':
                 urls.append(main_url)
             for value in myjson[main_url]:
                 for other_url, js in value.items():
-                    if is_similar(js):
+                    if js >= jaccard_similarity_threshold:
                         urls.append(other_url)
             current_index = -1
             other_indices = set()
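With is_similar() removed, the 0.7 cutoff becomes the default of an optional third positional argument. A hedged usage sketch; the file names are placeholders:

# Hypothetical invocations (file names are placeholders):
#   python tools/openwebtext/group_duplicate_url.py duplicates.json groups.json
#   python tools/openwebtext/group_duplicate_url.py duplicates.json groups.json 0.9
# The first form keeps the default threshold of 0.7; the second groups only
# pairs whose stored Jaccard similarity is at least 0.9.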