OpenDAS / ColossalAI · Commits

Commit 203ca57a (unverified)
Authored Nov 08, 2022 by Jiarui Fang; committed via GitHub on Nov 08, 2022

[example] add GPT

Parent: fd2c8d81
Showing 14 changed files with 1975 additions and 0 deletions (+1975 -0)
examples/language/gpt/tools/Megatron/cleanup_dataset.py (+107 -0)
examples/language/gpt/tools/Megatron/cleanup_fix_dataset.py (+191 -0)
examples/language/gpt/tools/Megatron/find_duplicates.py (+314 -0)
examples/language/gpt/tools/Megatron/gpt2_tokenization.py (+305 -0)
examples/language/gpt/tools/Megatron/group_duplicate_url.py (+85 -0)
examples/language/gpt/tools/Megatron/remove_group_duplicates.py (+64 -0)
examples/language/gpt/tools/Megatron/tokenizer.py (+36 -0)
examples/language/gpt/tools/download/download.py (+347 -0)
examples/language/gpt/tools/download/download_old.py (+58 -0)
examples/language/gpt/tools/download/filter.py (+110 -0)
examples/language/gpt/tools/download/get_urls.py (+32 -0)
examples/language/gpt/tools/download/scrapers.py (+121 -0)
examples/language/gpt/tools/download/utils.py (+62 -0)
examples/language/gpt/train_gpt.py (+143 -0)
examples/language/gpt/tools/Megatron/cleanup_dataset.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import time

import ftfy
import numpy as np
from langdetect import detect
from tokenizer import Tokenizer

MIN_DOCUMENT_LENGTH = 128


def print_progress(prefix, start_time, num_docs, num_fixed_text, num_non_english_docs, chars_non_english_docs,
                   num_small_docs, chars_small_docs):
    string = prefix + ' | '
    string += 'elapsed time: {:.2f} | '.format(time.time() - start_time)
    string += 'documents: {} | '.format(num_docs)
    string += 'fixed text: {} | '.format(num_fixed_text)
    string += 'non-english: {} | '.format(num_non_english_docs)
    string += 'non-english chars: {} | '.format(chars_non_english_docs)
    string += 'small docs: {} | '.format(num_small_docs)
    string += 'small docs chars: {}'.format(chars_small_docs)
    print(string, flush=True)


def filter_corpus(filename, out_filename, print_interval=10000):

    print(' > filtering {}'.format(filename))

    tokenizer = Tokenizer(cache_dir='./cache')

    num_docs = 0
    num_written_docs = 0
    num_small_docs = 0
    num_fixed_text = 0
    num_non_english_docs = 0
    chars_non_english_docs = 0
    chars_small_docs = 0
    start_time = time.time()
    with open(out_filename, 'wb') as f:
        with open(filename, 'r') as fin:
            for line in fin:
                try:
                    num_docs += 1
                    myjson = json.loads(line)
                    # Fix text
                    text = ftfy.fix_text(myjson['text'])
                    if text != myjson['text']:
                        num_fixed_text += 1
                    myjson['text'] = text
                    # Detect language.
                    if detect(text) != 'en':
                        print('[non-english text]', myjson)
                        num_non_english_docs += 1
                        chars_non_english_docs += len(text)
                        continue
                    # On average each token is 5 characters so 8 is an
                    # upper bound.
                    if len(text) < (8 * MIN_DOCUMENT_LENGTH):
                        tokens = tokenizer.tokenize_document(text)
                        if len(tokens) < MIN_DOCUMENT_LENGTH:
                            print('[small document, skipping]:', myjson)
                            num_small_docs += 1
                            chars_small_docs += len(text)
                            continue
                    myjson = json.dumps(myjson, ensure_ascii=False)
                    f.write(myjson.encode('utf-8'))
                    f.write('\n'.encode('utf-8'))
                    num_written_docs += 1
                    if num_docs % print_interval == 0:
                        print_progress('[PROGRESS]', start_time, num_docs, num_fixed_text, num_non_english_docs,
                                       chars_non_english_docs, num_small_docs, chars_small_docs)
                except Exception as e:
                    print('    skipping ', line, e)

    print_progress('[FINAL]', start_time, num_docs, num_fixed_text, num_non_english_docs, chars_non_english_docs,
                   num_small_docs, chars_small_docs)


if __name__ == '__main__':

    print('building gpt2 dataset ...')

    input_filename = sys.argv[1]
    output_filename = sys.argv[2]

    print('will be reading {}'.format(input_filename))
    print('and will write the results to {}'.format(output_filename))

    filter_corpus(input_filename, output_filename)
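A minimal smoke test for the filter above, as a hedged sketch: the file names are made up for illustration, and it assumes ftfy, langdetect, and the GPT-2 vocab download performed by tokenizer.Tokenizer are available in the environment.

# Not part of the commit: build a tiny loose-json input and run filter_corpus on it.
import json

docs = [{'text': 'Too short to keep.'},
        {'text': 'A longer English document that should survive the length filter. ' * 40}]
with open('tiny_input.json', 'w') as f:
    for d in docs:
        f.write(json.dumps(d) + '\n')

# Equivalent to: python cleanup_dataset.py tiny_input.json tiny_output.json
from cleanup_dataset import filter_corpus
filter_corpus('tiny_input.json', 'tiny_output.json')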
examples/language/gpt/tools/Megatron/cleanup_fix_dataset.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Filter and clean documents:
Capable to clean docs with less than 512 characters, less than
256 characters and contains javascript, fix text and dataset specific
cleaning like stories and realnews datasets.
Program arguments have the details.
"""
import argparse
import glob
import json
import multiprocessing
import os
import re
import time
from functools import partial
from pathlib import Path

import ftfy
from langdetect import detect


def process_doc(json_line, args):

    # Read the line.
    document = json.loads(json_line)
    text = document['text']

    output = {'remove_512': False, 'remove_256_javascript': False, \
        'remove_512_non_english': False, 'ftfy_fix_text': False, \
        'general_cleaning': False}

    try:
        # Remove all docs with less than 512 characters
        if "remove_512" in args.tasks:
            if len(text) < 512:
                output['remove_512'] = True
                return output, text, document, True

        # Remove docs if less than 256 character length and contains Javascript
        if "remove_256_javascript" in args.tasks:
            if len(text) < 256 and 'javascript' in text.lower():
                output['remove_256_javascript'] = True
                return output, text, document, True

        # Remove docs < 512 and nonenglish
        if "remove_512_non_english" in args.tasks:
            if len(text) < 512 and detect(text) != 'en':
                output['remove_512_non_english'] = True
                return output, text, document, True

        # Fix the text using ftfy, don't remove the text, hence return False
        if "ftfy_fix_text" in args.tasks:
            fixed_text = ftfy.fix_text(text)
            output['ftfy_fix_text'] = True
            return output, fixed_text, document, False

        # Cleaning extra spaces and newlines
        if "general_cleaning" in args.tasks:
            cleaned_text = re.sub(r" +|\b\n+ |\b\n+", " ", text)
            #cleaned_text = re.sub(r"\n\n+", "\n\n", text) # used this for Gutenberg dataset
            #cleaned_text = re.sub(r"\n", "\n\n", text) # Used this for realnews

            # stories datasets
            #cleaned_text = re.sub(r" \'", "'", text)
            #cleaned_text = re.sub(r" \!", "!", cleaned_text)
            #cleaned_text = re.sub(r" \.", ".", cleaned_text)
            #cleaned_text = re.sub(r" \?", "?", cleaned_text)
            #cleaned_text = re.sub(r" - ", "-", cleaned_text)
            ##cleaned_text = re.sub(r"\" ", "\"", cleaned_text)
            #cleaned_text = re.sub(r" @ ", "@", cleaned_text)

            output['general_cleaning'] = True
            return output, cleaned_text, document, False

    except Exception as e:
        print('Error: *************************\n{}\ntext: {}'.format(e, \
            text), flush=True)
        return output, text, document, True

    # don't remove
    return output, text, document, False


def process_set(args, input_file, output_f_cleaned, output_f_filtered):

    print(' > working on {} ...'.format(input_file), flush=True)

    num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
        = num_ftfy_fix_text = num_general_cleaning = 0

    # Output file and counters.
    output_cleaned = open(output_f_cleaned, 'wb')
    output_filtered = open(output_f_filtered, 'wb')

    start_time = time.time()

    # Setup multi-processing.
    num_workers = 40
    fin = open(input_file, 'r', encoding='utf-8')
    pool = multiprocessing.Pool(num_workers)
    process_doc_partial = partial(process_doc, args=args)
    processed_docs = pool.imap(process_doc_partial, fin, 500)

    # Process documents.
    for output, text, document, to_filter in processed_docs:
        num_docs += 1

        num_remove_512 += 1 if output['remove_512'] else 0
        num_remove_java += 1 if output['remove_256_javascript'] else 0
        num_remove_512_non_english += 1 if output['remove_512_non_english'] \
            else 0
        num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
        num_general_cleaning += 1 if output['general_cleaning'] else 0

        document['text'] = text
        myjson = json.dumps(document, ensure_ascii=False)

        if to_filter:
            output_filtered.write(myjson.encode('utf-8'))
            output_filtered.write('\n'.encode('utf-8'))
        else:
            output_cleaned.write(myjson.encode('utf-8'))
            output_cleaned.write('\n'.encode('utf-8'))

        if num_docs % args.log_interval == 0:
            print('    processed {:9d} documents in {:.2f} seconds ...'.format(num_docs,
                                                                               time.time() - start_time),
                  flush=True)

    # Close the file.
    output_cleaned.close()
    output_filtered.close()
    fin.close()

    # Print stats.
    print('  >> total docs: {} remove_512 {} remove_256_javascript {} '\
          'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.\
          format(num_docs, num_remove_512, num_remove_java,\
                 num_remove_512_non_english, num_ftfy_fix_text, \
                 num_general_cleaning), flush=True)


if __name__ == '__main__':

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-files', nargs='*', required=True, default=None,
                        help='Input json files that needs to be cleaned')
    parser.add_argument('--tasks', nargs='*', required=True, default=None,
                        help='Tasks to perform on the input files, '
                        'such as remove_512, remove_256_javascript, '
                        'remove_512_non_english, ftfy_fix_text, and '
                        'general_cleaning. 256 or 512 means the number'
                        ' of characters.')
    parser.add_argument('--output-path', type=str, default=None, help='Directory where the output should go')
    parser.add_argument('--log-interval', type=int, default=100, help='Log interval')

    args = parser.parse_args()

    print('cleanup dataset ...')

    for input_file in args.input_files:
        input_filename, input_filename_ext = os.path.splitext(Path(input_file).name)

        output_f_cleaned = os.path.join(args.output_path, input_filename + "_cleaned" + input_filename_ext)
        output_f_filtered = os.path.join(args.output_path, input_filename + "_filtered" + input_filename_ext)

        process_set(args, input_file, output_f_cleaned, output_f_filtered)

    print('done :-)', flush=True)
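A small illustration of process_doc on a single JSON line, as a hedged sketch: SimpleNamespace stands in for the argparse namespace, and ftfy/langdetect must be installed for the module import to succeed.

# Not part of the commit: route a too-short document through the remove_512 task.
import json
from types import SimpleNamespace

from cleanup_fix_dataset import process_doc

args = SimpleNamespace(tasks=['remove_512'])
line = json.dumps({'text': 'too short to keep'})
output, text, document, to_filter = process_doc(line, args)
print(output['remove_512'], to_filter)    # True True -> this doc goes to the *_filtered output file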
examples/language/gpt/tools/Megatron/find_duplicates.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
import json
import multiprocessing
import os
import pickle
import sys
import time
from functools import partial

import numpy as np
from lsh import cache, minhash


# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
    return set(text[head:head + char_ngram] for head in range(0, len(text) - char_ngram))


# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b, args):
    if len(set_a) < 1 or len(set_b) < 1:
        return 0.0

    intersection = set_a & set_b
    union = set_a | set_b

    if args.jaccard == 'min':
        return len(intersection) / min(len(set_a), len(set_b))
    elif args.jaccard == 'max':
        return len(intersection) / max(len(set_a), len(set_b))
    else:
        return len(intersection) / len(union)


def compute_fingerprint(line, key):
    try:
        myjson = json.loads(line)
        url = myjson[key]
        text = myjson['text']
        fingerprint = hasher.fingerprint(text)
    except Exception as e:
        print('Error:', e)
        return None, None, None, False

    return url, text, fingerprint, True


def url_pairs_to_remove(args, bucket_urls, url_doc):
    remove_urls_list = []
    deduped_local, counter_local = 0, 0
    iteration = 0
    while len(bucket_urls) > 1:
        if args.heuristic_iter != -1 and \
                iteration == args.heuristic_iter:
            break

        items = list(bucket_urls)
        remove_urls = []
        main_url = items[np.random.randint(0, len(items))]
        main_shingles = shingles(url_doc[main_url])

        for i in range(0, len(items)):
            counter_local += 1
            other_url = items[i]
            if other_url == main_url:
                continue
            other_shingles = shingles(url_doc[other_url])

            try:
                jaccard_sim = jaccard(main_shingles, other_shingles, args)
            except Exception as e:
                print('Error:', e)
                jaccard_sim = 0.0
            if jaccard_sim > 0.5:
                remove_urls.append({other_url: jaccard_sim})
                deduped_local += 1
                bucket_urls.remove(other_url)

        bucket_urls.remove(main_url)
        if len(remove_urls) > 0:
            remove_urls_list.append({main_url: remove_urls})
        iteration += 1
    return remove_urls_list, deduped_local, counter_local


def write_remove_urls_list(remove_urls_list, f_out):
    if len(remove_urls_list) > 0:
        for each_url_remove in remove_urls_list:
            myjson = json.dumps(each_url_remove, ensure_ascii=False)
            f_out.write(myjson.encode('utf-8'))
            f_out.write('\n'.encode('utf-8'))


def compute_jaccard(each_bin, num_bins, start_time_local):

    remove_urls_list = []
    deduped_local, counter_local, bucket_local = 0, 0, 0

    for bucket_id in each_bin:
        bucket_local += 1
        if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
            print("Counter {}, progress {:.2f} time {:.2f}".\
                  format(bucket_local, float(bucket_local) / float(len(each_bin)),\
                         time.time() - start_time_local), flush=True)

        if len(each_bin[bucket_id]) <= 1:
            continue

        bucket_urls = each_bin[bucket_id].copy()
        remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
            url_pairs_to_remove(args, bucket_urls, url_doc)

        deduped_local += deduped_local_sub
        counter_local += counter_local_sub
        if len(remove_urls_list_sub) > 0:
            remove_urls_list.extend(remove_urls_list_sub)

    return remove_urls_list, deduped_local, counter_local


def find_pair_urls_parallel(args, lshcache, url_doc):
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0

    # compute jaccards of buckets in bin in parallel (parallelism
    # limited to # of bins)
    num_bins = len(lshcache.bins)
    pool = multiprocessing.Pool(num_bins)
    compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
                                      start_time_local=start_time)
    # don't need to pass args and url_doc as they are already shared
    compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)

    print("multiprocessing init took {:.2f}".format(time.time() - start_time), \
          flush=True)
    for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
        deduped += deduped_local
        counter += counter_local
        write_remove_urls_list(remove_urls_list, f_out)
        print(' [write]> processed {} documents in {:.2f} '
              'seconds and deduped {} documents ...'.format(counter, time.time() \
                                                            - start_time, deduped), flush=True)

    pool.close()
    pool.join()
    f_out.close()

    print(' Taken time for jaccard similarities {:.2f} seconds'.format(\
          time.time() - start_time), flush=True)


def find_pair_urls_sequential(args, lshcache, url_doc):
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0
    for b in lshcache.bins:
        for bucket_id in b:
            if len(b[bucket_id]) <= 1:
                continue

            bucket_urls = b[bucket_id].copy()
            remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
                url_pairs_to_remove(args, bucket_urls, url_doc)

            deduped += deduped_local_sub
            counter += counter_local_sub
            write_remove_urls_list(remove_urls_list_sub, f_out)
            if counter % 10000 == 0:
                print(' [write]> processed {} documents in {:.2f} '
                      'seconds and deduped {} documents ...'.format(counter,
                                                                    time.time() - start_time,
                                                                    deduped),
                      flush=True)
    f_out.close()
    print(' [write]> processed {} documents in {:.2f} '
          'seconds and deduped {} documents ...'.format(counter,
                                                        time.time() - start_time,
                                                        deduped),
          flush=True)


if __name__ == '__main__':

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234, help='Random seed used for python, numpy')
    parser.add_argument('--inputs', nargs='*', default=None,
                        help='Pairwise list of the input files and keys, '
                        'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs='*', default=None,
                        help='Load fingerprints from a list of pickle files,'
                        ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                        help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name that consists of all ids'
                        ' with matching similarities')
    parser.add_argument('--jaccard', type=str, default='union',
                        choices=['union', 'min', 'max'], help='Jaccard'
                        ' similarity computation')
    parser.add_argument('--heuristic-iter', type=int, default=1,
                        help='Number of iterations to run the heuristics'
                        ': use -1 for exact')
    parser.add_argument('--num-bands', type=int, default=10,
                        help='Number of bands to use in cache')
    parser.add_argument('--num-seeds', type=int, default=100,
                        help='Number of seeds to use for minhash. Note that'
                        ' this value should be divisible by num-bands')
    parser.add_argument('--jaccard-parallel', action='store_true',
                        help='Use this to process large number of documents.')
    args = parser.parse_args()

    print('finding possible duplicate content ...')

    # set seed and get an array of seeds of 100 integers
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1e6, size=args.num_seeds)

    # initialize minhash and lsh cache
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

    url_doc = {}

    # load fingerprints from pickle file if needed
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(fp_file_name), flush=True)
            fp = open(fp_file_name, "rb")
            if count_fp == 0:
                # assign directory for the first pkl
                lshcache = pickle.load(fp)
                url_doc = pickle.load(fp)
            else:
                # append these to lshcache and url_doc
                local_lshcache = pickle.load(fp)
                local_url_doc = pickle.load(fp)
                for url in local_lshcache.fingerprints.keys():
                    url_doc[url] = local_url_doc[url]
                    lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
            fp.close()

    counter = 0
    start_time = time.time()

    # compute finger prints of the inputs if any
    # input file and the key to use as id
    if args.inputs is not None:
        print("Computing fingerprints", flush=True)
        assert len(args.inputs) % 2 == 0
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key), flush=True)

            # compute fingerprints in parallel
            num_workers = 40
            pool = multiprocessing.Pool(num_workers)
            fin = open(input_file, 'r', encoding='utf-8')
            compute_fingerprint_partial = partial(compute_fingerprint, key=key)
            compute_fingerprint_iter = pool.imap(compute_fingerprint_partial, fin, 512)
            # traverse all the texts and add fingerprints
            for url, text, fingerprint, flag in compute_fingerprint_iter:
                counter += 1
                if flag:
                    url_doc[url] = text
                    lshcache.add_fingerprint(fingerprint, url)
                if counter % 10000 == 0:
                    print(' [read]> processed {} documents in {:.2f} '
                          'seconds ...'.format(counter, time.time() - \
                                               start_time), flush=True)

            fin.close()
            pool.close()
            pool.join()

    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(args.save_fingerprints), flush=True)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)

    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        print("Compute jaccard similarity", flush=True)
        if args.jaccard_parallel:
            find_pair_urls_parallel(args, lshcache, url_doc)
        else:
            find_pair_urls_sequential(args, lshcache, url_doc)

    print('done :-)')
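A quick look at the shingle/Jaccard primitives used above, as a hedged sketch: importing find_duplicates assumes the lsh package it depends on (https://github.com/mattilyra/LSH) is installed, and SimpleNamespace stands in for the argparse namespace.

# Not part of the commit: compare two near-duplicate strings with the same
# character 5-gram shingling used for URL deduplication.
from types import SimpleNamespace

from find_duplicates import jaccard, shingles

a = shingles("the quick brown fox jumps over the lazy dog")
b = shingles("the quick brown fox jumped over the lazy dog")
args = SimpleNamespace(jaccard='union')
print(jaccard(a, b, args))    # well above the 0.5 threshold used in url_pairs_to_remove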
examples/language/gpt/tools/Megatron/gpt2_tokenization.py (new file, mode 100644)
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function, unicode_literals)

import json
import logging
import os
import sys
from io import open

import regex as re

try:
    from functools import lru_cache
except ImportError:
    # Just a dummy decorator to get the checks to run on python2
    # because honestly I don't want to support a byte-level unicode BPE
    # tokenizer on python 2 right now.
    def lru_cache():
        return lambda func: func


logger = logging.getLogger(__name__)

PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'gpt2': 1024,
}

VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    _chr = unichr if sys.version_info[0] == 2 else chr
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
        list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [_chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class GPT2Tokenizer(object):
    """
    GPT-2 BPE tokenizer. Peculiarities:
        - Byte-level BPE
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
            special_tokens_file = None
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
            if not os.path.exists(special_tokens_file):
                special_tokens_file = None
            else:
                logger.info("loading special tokens file {}".format(special_tokens_file))
        # redirect to the cache, if necessary
        try:
            from cached_path import cached_path
            resolved_vocab_file = cached_path(vocab_file)
            resolved_merges_file = cached_path(merges_file)
        except EnvironmentError:
            logger.error("Model name '{}' was not found in model name list ({}). "
                         "We assumed '{}' was a path or url but couldn't find files {} and {} "
                         "at this path or url.".format(pretrained_model_name_or_path,
                                                       ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                                                       pretrained_model_name_or_path, vocab_file, merges_file))
            return None
        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
            logger.info("loading merges file {}".format(merges_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(vocab_file, resolved_vocab_file))
            logger.info("loading merges file {} from cache at {}".format(merges_file, resolved_merges_file))
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
            # than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        if special_tokens_file and 'special_tokens' not in kwargs:
            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
        else:
            special_tokens = kwargs.pop('special_tokens', [])
        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
        return tokenizer

    def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors    # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should haved added re.IGNORECASE so BPE merges can happen for
        # capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        self.special_tokens = {}
        self.special_tokens_decoder = {}
        self.set_special_tokens(special_tokens)

    def __len__(self):
        return len(self.encoder) + len(self.special_tokens)

    def set_special_tokens(self, special_tokens):
        """ Add a list of additional tokens to the encoder.
            The additional tokens are indexed starting from the last index of the
            current vocabulary in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
        logger.info("Special tokens {}".format(self.special_tokens))

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except BaseException:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def tokenize(self, text):
        """ Tokenize a string. """
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            if sys.version_info[0] == 2:
                token = ''.join(self.byte_encoder[ord(b)] for b in token)
            else:
                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def convert_tokens_to_ids(self, tokens):
        """ Converts a sequence of tokens into ids using the vocab. """
        ids = []
        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.encoder.get(tokens, 0)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.encoder.get(token, 0))
        if len(ids) > self.max_len:
            logger.warning("Token indices sequence length is longer than the specified maximum "
                           " sequence length for this OpenAI GPT model ({} > {}). Running this"
                           " sequence through the model will result in indexing errors".format(
                               len(ids), self.max_len))
        return ids

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Converts a sequence of ids in BPE tokens using the vocab."""
        tokens = []
        for i in ids:
            if i in self.special_tokens_decoder:
                if not skip_special_tokens:
                    tokens.append(self.special_tokens_decoder[i])
            else:
                tokens.append(self.decoder[i])
        return tokens

    def encode(self, text):
        return self.convert_tokens_to_ids(self.tokenize(text))

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(vocab_path):
            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
            return
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        merge_file = os.path.join(vocab_path, MERGES_NAME)
        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)

        with open(vocab_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write(u'#version: 0.2\n')
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1

        index = len(self.encoder)
        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1

        return vocab_file, merge_file, special_tokens_file
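A quick check of the byte-level BPE helpers defined above, as a hedged sketch: it only needs the regex package that gpt2_tokenization itself imports, and no vocabulary files.

# Not part of the commit: show the reversible byte-to-unicode remapping and the pair extraction.
from gpt2_tokenization import bytes_to_unicode, get_pairs

byte_to_char = bytes_to_unicode()
print(byte_to_char[ord(' ')])                   # 'Ġ': whitespace is remapped to a printable character
print(get_pairs(('h', 'e', 'l', 'l', 'o')))     # {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}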
examples/language/gpt/tools/Megatron/group_duplicate_url.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
import time

if __name__ == '__main__':

    print('grouping duplicate urls ...')

    input = sys.argv[1]
    output = sys.argv[2]
    if len(sys.argv) > 3:
        jaccard_similarity_threshold = float(sys.argv[3])
    else:
        jaccard_similarity_threshold = 0.7

    url_to_index = {}
    index_to_urls = []
    counter = 0
    start_time = time.time()
    with open(input, 'r') as f:
        for line in f:
            counter += 1
            myjson = json.loads(line)
            urls = []
            for main_url in myjson.keys():
                urls.append(main_url)
                for value in myjson[main_url]:
                    for other_url, js in value.items():
                        if js >= jaccard_similarity_threshold:
                            urls.append(other_url)
            current_index = -1
            other_indices = set()
            for url in urls:
                if url in url_to_index:
                    if current_index == -1:
                        current_index = url_to_index[url]
                    elif current_index != url_to_index[url]:
                        other_indices.add(url_to_index[url])
            if current_index == -1:
                current_index = len(index_to_urls)
                index_to_urls.append(set())
            for url in urls:
                url_to_index[url] = current_index
                index_to_urls[current_index].add(url)
            for index in other_indices:
                for url in index_to_urls[index]:
                    index_to_urls[current_index].add(url)
                    url_to_index[url] = current_index
                index_to_urls[index] = None

            if counter % 100000 == 0:
                print(' > processed {} lines in {} seconds ...'.format(counter, time.time() - start_time))

    total_remove = 0
    total_remain = 0
    for urls in index_to_urls:
        if urls is not None:
            if len(urls) > 1:
                total_remove += (len(urls) - 1)
                total_remain += 1
    print('out of {} urls, only {} are unique and {} should be removed'.format(total_remove + total_remain,
                                                                               total_remain, total_remove))

    with open(output, 'wb') as f:
        for i, urls in enumerate(index_to_urls):
            if urls is not None:
                if len(urls) > 1:
                    myjson = json.dumps({str(i): list(urls)}, ensure_ascii=False)
                    f.write(myjson.encode('utf-8'))
                    f.write('\n'.encode('utf-8'))
examples/language/gpt/tools/Megatron/remove_group_duplicates.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
import time

if __name__ == '__main__':

    url_filename = sys.argv[1]
    data_filename = sys.argv[2]
    output_filename = sys.argv[3]

    urls = set()
    with open(url_filename, 'r') as f:
        for line in f:
            myjson = json.loads(line)
            for key in myjson:
                this_urls = myjson[key]
                for i in range(1, len(this_urls)):
                    urls.add(this_urls[i])
    print('will be removing {} urls'.format(len(urls)), flush=True)

    written_docs = 0
    removed_docs = 0
    removed_chars = 0
    start_time = time.time()
    with open(output_filename, 'wb') as fout:
        with open(data_filename, 'r') as fin:
            for line in fin:
                try:
                    myjson = json.loads(line)
                    url = myjson['url']
                    if url in urls:
                        print('removing', myjson)
                        removed_docs += 1
                        removed_chars += len(myjson['text'])
                        continue
                    myjson = json.dumps(myjson, ensure_ascii=False)
                    fout.write(myjson.encode('utf-8'))
                    fout.write('\n'.encode('utf-8'))
                    written_docs += 1
                    if written_docs % 10000 == 0:
                        print(' [PROCESSED] time (s): {:.2f} | written: {} '
                              '| removed: {} (char: {})'.format(time.time() - start_time, written_docs,
                                                                removed_docs, removed_chars))
                except Exception as e:
                    print('[SKIPPING]', line, e)

    print(' [PROCESSED] time (s): {:.2f} | written: {} '
          '| removed: {} (char: {})'.format(time.time() - start_time, written_docs, removed_docs, removed_chars))
    print('done :-)')
examples/language/gpt/tools/Megatron/tokenizer.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

sys.path.append('..')

from gpt2_tokenization import GPT2Tokenizer


class Tokenizer:

    def __init__(self, cache_dir=None):
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
        self.tokenizer.max_len = int(1e12)
        self.eod_token = self.tokenizer.encoder['<|endoftext|>']
        assert self.eod_token < 65535, 'vocab size will not fit in uint16'
        print('> GPT2 tokenizer with {} vocab size and eod token {} ...'.format(len(self.tokenizer.encoder),
                                                                                self.eod_token))

    def tokenize_document(self, document):
        tokens = self.tokenizer.encode(document)
        tokens.append(self.eod_token)
        return tokens
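How cleanup_dataset.py uses this wrapper, as a hedged sketch: constructing Tokenizer() resolves the GPT-2 vocab/merges through the URLs hard-coded in gpt2_tokenization.py, so this needs network access (and those URLs may no longer be served).

# Not part of the commit: tokenize one document and check the end-of-document marker.
from tokenizer import Tokenizer

tok = Tokenizer(cache_dir='./cache')
ids = tok.tokenize_document('Hello world')
print(len(ids), ids[-1] == tok.eod_token)    # every document ends with the <|endoftext|> id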
examples/language/gpt/tools/download/download.py (new file, mode 100644)
# Code taken in large part from https://github.com/jcpeterson/openwebtext
from __future__ import print_function

import argparse
import io
import json
import multiprocessing as mpl
import os
import os.path as op
import sqlite3
import tarfile
import time
import warnings
from glob import glob
from hashlib import sha256

import tldextract
from scrapers import bs4_scraper, newspaper_scraper, raw_scraper
# for backward compatibility
from six.moves.urllib.request import urlopen
from tqdm import tqdm
from utils import chunks, extract_month, linecount, mkdir

parser = argparse.ArgumentParser()
parser.add_argument("url_file", type=str)
parser.add_argument(
    "--save_uncompressed",
    action="store_true",
    default=False,
    help="whether to save the raw txt files to disk",
)
parser.add_argument(
    "--output",
    type=str,
    default='raw.json',
    help="where to save the output json",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="scraped",
    help="which folder in the working directory to use for output",
)
parser.add_argument(
    "--n_procs",
    type=int,
    default=10,
    help="how many processes (cores) to use for parallel scraping",
)
parser.add_argument(
    "--timeout",
    type=int,
    default=-1,
    help="maximum scrape time for a single URL; -1 means no limit",
)
parser.add_argument(
    "--max_urls",
    type=int,
    default=-1,
    help="maximum # of URLs to scrape; mostly for debugging",
)
parser.add_argument(
    "--chunk_size",
    type=int,
    default=100,
    help="how many URLs to scrape before saving to archive",
)
parser.add_argument(
    "--scraper",
    type=str,
    default="newspaper",
    choices=["raw", "bs4", "newspaper"],
    help="which text/content scraper to use; raw is html",
)
parser.add_argument(
    "--compress",
    action="store_true",
    default=False,
    help="whether to output scraped content as compressed archives",
)
parser.add_argument(
    "--compress_fmt",
    type=str,
    default="xz",
    choices=["xz", "bz2", "gz"],
    help="which archive format to use",
)
parser.add_argument(
    "--scraper_memoize",
    action="store_true",
    default=False,
    help="whether to use cache for newspaper",
)
parser.add_argument(
    "--show_warnings",
    action="store_true",
    default=False,
    help="whether to show warnings in general during scraping",
)
parser.add_argument(
    "--sqlite_meta",
    action="store_true",
    default=True,
    help="whether to use sqlite for storing meta. if false, json will be used instead",
)
args = parser.parse_args()

if not args.show_warnings:
    # avoid lots of datetime warnings
    warnings.filterwarnings("ignore")


def load_urls(fh, max_urls=-1):
    url_entries = enumerate(fh)
    if max_urls != -1:
        url_entries = list(url_entries)[:max_urls]
    return url_entries


def vet_link(link):
    # check if server responds with non-200 status code or link points to a
    # non-html file
    link_type, link_status = "", -1
    try:
        info = urlopen(link)
        link_type = info.headers["Content-Type"]
        link_status = info.status
    except:
        pass

    # we want "text/html" only!
    is_good_link = False
    if "text/html" in link_type and link_status == 200:
        is_good_link = True

    return is_good_link, link_type


def download(url_entry,
             scraper=args.scraper,
             save_uncompressed=args.save_uncompressed,
             memoize=args.scraper_memoize,
             arch_meta=not args.sqlite_meta):

    uid, url = url_entry
    url = url.strip()
    fid = "{:07d}-{}".format(uid, sha256(url.encode()).hexdigest())

    data_dir = mkdir(op.join(args.output_dir, "data"))
    text_fp = op.join(data_dir, "{}.txt".format(fid))

    if arch_meta:
        meta_dir = mkdir(op.join(args.output_dir, "meta"))
        meta_fp = op.join(meta_dir, "{}.json".format(fid))

    # already downloaded!
    if op.exists(text_fp):
        return

    # is_good_link, link_type = vet_link(url)
    # if not is_good_link:
    #     return

    if scraper == "bs4":
        scrape = bs4_scraper
    elif scraper == "newspaper":
        scrape = newspaper_scraper
    elif scraper == "raw":
        scrape = raw_scraper

    text, meta = scrape(url, memoize)
    ext = tldextract.extract(url)
    domain = '.'.join([x for x in ext if x])
    meta["domain"] = domain

    if text is None or text.strip() == "":
        return ("", meta, fid, uid)

    if save_uncompressed:
        with open(text_fp, "w") as out:
            out.write(text)
        if arch_meta:
            with open(meta_fp, "w") as out:
                json.dump(meta, out)

    return (text, meta, fid, uid)


def archive_chunk(cid, cdata, out_dir, fmt, arch_meta):
    mkdir(out_dir)
    texts, metas, fids, uids = zip(*cdata)

    data_tar = op.join(out_dir, "{}_data.{}".format(cid, fmt))
    if arch_meta:
        meta_tar = op.join(out_dir, "{}_meta.{}".format(cid, fmt))
        tar_fps, texts, exts = [data_tar, meta_tar], [texts, metas], ["txt", "json"]
    else:
        tar_fps, texts, exts = [data_tar], [texts], ["txt"]

    doc_count = 0
    docs_counted = False
    for tar_fp, txts, ext in zip(tar_fps, texts, exts):
        with tarfile.open(tar_fp, "w:" + fmt) as tar:
            for f, fid in zip(txts, fids):
                if f == "":
                    continue
                else:
                    if not docs_counted:
                        doc_count += 1

                if ext == "json":
                    f = json.dumps(f)

                f = f.encode("utf-8")
                t = tarfile.TarInfo("{}.{}".format(fid, ext))
                t.size = len(f)
                tar.addfile(t, io.BytesIO(f))
        docs_counted = True

    return doc_count


def load_state(url_file):
    ckptfile = url_file + '.ckpt'
    if op.exists(ckptfile):
        with open(ckptfile) as fp:
            r = fp.read()
            if r == '':
                return 0
            else:
                return int(r)
    else:
        return 0


def save_state(url_file, cid):
    ckptfile = url_file + '.ckpt'
    with open(ckptfile, 'w') as fp:
        fp.write(str(cid))


def sqlite_conn():
    conn = sqlite3.connect('metadata.db')
    conn.execute('''
        CREATE TABLE IF NOT EXISTS metadata (
            fid char(64) not null primary key,
            url varchar(2048) not null,
            domain varchar(255) not null,
            word_count int null,
            elapsed int null,
            scraper varchar(255) not null,
            success boolean not null
        );
    ''')
    conn.execute('''
        CREATE INDEX IF NOT EXISTS ix_meta_url ON metadata(url);
    ''')
    conn.execute('''
        CREATE INDEX IF NOT EXISTS ix_meta_domain ON metadata(domain);
    ''')
    return conn


if __name__ == "__main__":
    if args.sqlite_meta:
        conn = sqlite_conn()
        cur = conn.cursor()

    start_elem = load_state(args.url_file)
    start_chnk = start_elem // args.chunk_size

    f_json = open(args.output, "w")

    # URLs we haven't scraped yet (if first run, all URLs in file)
    with open(args.url_file) as fh:
        url_entries = load_urls(fh, args.max_urls)

        pool = mpl.Pool(args.n_procs)

        total = linecount(args.url_file) // args.chunk_size
        print('Total chunks: ', total)
        chunk_iterator = tqdm(enumerate(chunks(url_entries, args.chunk_size, start_elem)), total=total)

        # display already-downloaded chunks on progress bar
        chunk_iterator.update(start_chnk)

        # process one "chunk" of args.chunk_size URLs at a time
        for i, chunk in chunk_iterator:
            cid = start_chnk + i + 1

            tqdm.write("Downloading chunk {}".format(cid))
            t1 = time.time()

            if args.timeout > 0:
                # imap as iterator allows .next() w/ timeout.
                # ordered version doesn't seem to work correctly.
                # for some reason, you CANNOT track j or chunk[j] in the loop,
                # so don't add anything else to the loop below!
                # confusingly, chunksize below is unrelated to our chunk_size
                chunk_iter = pool.imap_unordered(download, chunk, chunksize=1)
                cdata = []
                for j in range(len(chunk)):
                    try:
                        result = chunk_iter.next(timeout=args.timeout)
                        cdata.append(result)
                    except mpl.TimeoutError:
                        tqdm.write("   --- Timeout Error ---   ")
            else:
                cdata = list(pool.imap(download, chunk, chunksize=1))

            tqdm.write("{} / {} downloads timed out".format(len(chunk) - len(cdata), len(chunk)))
            tqdm.write("Chunk time: {} seconds".format(time.time() - t1))

            # write metadata to sqlite
            if args.sqlite_meta:
                for text, meta, fid, _ in filter(lambda x: x, cdata):
                    if text:
                        params = (fid, meta["url"], meta["domain"], meta["elapsed"], meta["word_count"],
                                  meta["scraper"], True)
                    else:
                        params = (fid, meta["url"], meta["domain"], None, None, meta["scraper"], False)
                    cur.execute(
                        "insert or ignore into metadata (fid, url, domain, elapsed, word_count, scraper, success) values (?, ?, ?, ?, ?, ?, ?)",
                        params)
                conn.commit()

            dump_chunk = []
            for text, meta, fid, _ in filter(lambda x: x, cdata):
                if text:
                    line_json = {"text": text, "url": meta["url"]}
                    dump_chunk.append(json.dumps(line_json) + '\n')
            f_json.writelines(dump_chunk)

            # archive and save this chunk to file
            if args.compress:
                tqdm.write("Compressing...")
                t2 = time.time()
                count = archive_chunk(cid, cdata, args.output_dir, args.compress_fmt, not args.sqlite_meta)
                tqdm.write("Archive created in {} seconds".format(time.time() - t2))

            tqdm.write("{} out of {} URLs yielded content\n".format(len(list(filter(lambda x: x and x[0], cdata))),
                                                                    len(chunk)))
            save_state(args.url_file, cid * args.chunk_size)

    f_json.close()
    print("Done!")
examples/language/gpt/tools/download/download_old.py (new file, mode 100644)
import hashlib
import multiprocessing as mp
import os
import traceback

import newspaper
import tldextract
import tqdm
from filter import should_exclude

hash = hashlib.sha256

try:
    os.mkdir('data')
except FileExistsError:
    pass


def dl(url):
    url = url.strip()

    if should_exclude(url):
        return

    ext = tldextract.extract(url)
    domain = '.'.join([x for x in ext if x])
    fname = 'data/{}-{}.txt'.format(domain, hash(url.encode()).hexdigest())

    if os.path.isfile(fname):
        return

    # print('Downloading', url)
    try:
        article = newspaper.Article(url, fetch_images=False)
        article.download()
        article.parse()
    except newspaper.article.ArticleException:
        # print('Dead link:', url)
        return
        # traceback.print_exc()

    text = article.text

    if text.strip() == '':
        # print('Empty')
        return

    with open(fname, 'w') as out:
        out.write(text)


if __name__ == '__main__':
    p = mp.Pool(100)    # num of download threads

    with open('urls.txt') as fh:
        urls = list(fh)

    list(tqdm.tqdm(p.imap(dl, urls), total=len(urls)))
    print('Done!')
examples/language/gpt/tools/download/filter.py (new file, mode 100644)
import re

import tldextract
import tqdm
from utils import linecount

# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
url_regex = re.compile(
    r'^(?:http)s?://'    # http:// or https://
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'    #domain...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'    # ...or ip
    r'(?::\d+)?'    # optional port
    r'(?:/?|[/?]\S+)$',
    re.IGNORECASE)

# domains that aren't scraper friendly. do not include subdomains!
exclude_domains = set([
    # image & video hosting sites
    'imgur.com',
    'redd.it',
    'instagram.com',
    'discord.gg',
    'gfycat.com',
    'giphy.com',
    'reddituploads.com',
    'redditmedia.com',
    'twimg.com',
    'sli.mg',
    'magaimg.net',
    'flickr.com',
    'imgflip.com',
    'youtube.com',
    'youtu.be',
    'youtubedoubler.com',
    'vimeo.com',
    'twitch.tv',
    'streamable.com',
    'bandcamp.com',
    'soundcloud.com',

    # not scraper friendly
    'reddit.com',
    'gyazo.com',
    'github.com',
    'xkcd.com',
    'twitter.com',
    'spotify.com',
    'itunes.apple.com',
    'facebook.com',
    'gunprime.com',
    'strawpoll.me',
    'voyagefusion.com',
    'rollingstone.com',
    'google.com',
    'timeanddate.com',
    'walmart.com',
    'roanoke.com',
    'spotrac.com',

    # original paper excluded wikipedia
    'wikipedia.org',

    # lots of top posts for this one
    'battleforthenet.com',
])

exclude_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.gifv', '.pdf', '.mp4', '.mp3', '.ogv', '.webm', '.doc',
                      '.docx', '.log', '.csv', '.dat', '.iso', '.bin', '.exe', '.apk', '.jar', '.app', '.ppt', '.pps',
                      '.pptx', '.xml', '.gz', '.xz', '.bz2', '.tgz', '.tar', '.zip', '.wma', '.mov', '.wmv', '.3gp',
                      '.svg', '.rar', '.wav', '.avi', '.7z')


def should_exclude(url):
    ext = tldextract.extract(url)
    domain = '.'.join([x for x in ext if x])
    basedomain = '.'.join(ext[-2:])

    # Ignore non-URLs
    if len(url) <= 8 or ' ' in url or re.match(url_regex, url) is None:
        return True

    # Ignore excluded domains
    if basedomain in exclude_domains or domain in exclude_domains:
        return True

    # Ignore case-insensitive matches for excluded extensions
    if url.lower().split('?')[0].endswith(exclude_extensions):
        return True

    return False


if __name__ == '__main__':
    url_file = 'urls.txt'
    filtered_file = 'urls-filtered.txt'

    with open(url_file) as urls, open(filtered_file, 'w') as out:
        url_len = linecount(url_file)
        print("URL file is", url_len, "URLs long.")

        url_set = set()
        for line in tqdm.tqdm(urls, total=url_len):
            if len(line.strip()) == 0:
                continue    # Skip whitespace-only lines

            line = line.strip().split()[0]    # Drop any components following whitespace

            if should_exclude(line):
                continue

            url_set.add(line)

        for line in tqdm.tqdm(url_set):
            out.write(line + '\n')
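should_exclude in action, as a hedged sketch: tldextract (and the sibling utils module) must be importable, and the URLs below are illustrative only.

# Not part of the commit: the three rejection paths and one accepted URL.
from filter import should_exclude

print(should_exclude('https://imgur.com/gallery/abc'))       # True: domain is on the exclude list
print(should_exclude('https://example.com/report.pdf'))      # True: excluded file extension
print(should_exclude('not a url'))                           # True: fails the URL regex
print(should_exclude('https://example.com/some-article'))    # False: kept for scraping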
examples/language/gpt/tools/download/get_urls.py (new file, mode 100644)
import datetime

import praw
import psaw
import tqdm

api = psaw.PushshiftAPI()

# all posts until the end of 2017
end_time = int(datetime.datetime(2018, 1, 1).timestamp())

query = api.search_submissions(before=end_time,
                               filter=['url', 'score'],
                               sort='desc',
                               score='>2',
                               is_self=False,
                               over_18=False)

with tqdm.tqdm() as pbar:
    # download links from submissions
    with open('urls.txt', 'w') as fh:
        for subm in query:
            url = subm.url
            # weird issue with psaw/pushshift that breaks score=">2"
            if subm.score < 3:
                continue
            #print(subm.score)
            # pbar.write(str(datetime.datetime.fromtimestamp(subm.created_utc)))
            pbar.update(1)
            fh.write(url + '\n')
            fh.flush()
examples/language/gpt/tools/download/scrapers.py (new file, mode 100644)
# Code taken in large part from https://github.com/jcpeterson/openwebtext

import time
import unicodedata

import bs4
import newspaper
from filter import should_exclude
from htmlmin import minify
from lxml.html.clean import Cleaner


def find_and_filter_tag(tag, soup):
    """tag specific filter logic"""
    candidates = soup.find_all(tag)
    candidates = [unicodedata.normalize("NFKD", x.string) for x in candidates if x.string is not None]

    if tag == "p":
        # keep only paragraphs with at least 4 words and count the total words
        candidates = [y.strip() for y in candidates if len(y.split(" ")) >= 4]
        count = sum(len(y.split(" ")) for y in candidates)
    else:
        raise NotImplementedError

    return (candidates, count)


def raw_scraper(url, memoize):
    t1 = time.time()
    if should_exclude(url):
        # heuristic to make downloading faster
        return None, {
            "url": url,
            "scraper": "raw",
        }

    try:
        cleaner = Cleaner()
        cleaner.javascript = True
        cleaner.style = True
        article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
        article.download()
        html = minify(article.html)
        html = cleaner.clean_html(html)
        article.parse()
    except Exception:    # any download/clean/parse failure is treated as a miss
        return None, {
            "url": url,
            "scraper": "raw",
        }
    if article.text == "":
        return None, {
            "url": url,
            "scraper": "raw",
        }

    metadata = {"url": url, "elapsed": time.time() - t1, "scraper": "raw"}
    return html, metadata


def newspaper_scraper(url, memoize):
    t1 = time.time()
    if should_exclude(url):
        # heuristic to make downloading faster
        return None, {
            "url": url,
            "scraper": "newspaper",
        }

    try:
        article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
        article.download()
        article.parse()
        text = article.text
        count = len(text.split())
    except Exception:    # any download/parse failure is treated as a miss
        return None, {
            "url": url,
            "scraper": "newspaper",
        }

    metadata = {
        "url": url,
        "word_count": count,
        "elapsed": time.time() - t1,
        "scraper": "newspaper",
    }
    return text, metadata


def bs4_scraper(url, memoize):
    t1 = time.time()
    if should_exclude(url):
        # heuristic to make downloading faster
        return None, {
            "url": url,
            "scraper": "bs4",
        }

    try:
        article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
        article.download()
        html = article.html
        soup = bs4.BeautifulSoup(html, "lxml")
        text, count = find_and_filter_tag("p", soup)
        # DDB: keep text as a single string for consistency with
        # newspaper_scraper
        text = " ".join(text)
    except Exception:    # any download/parse failure is treated as a miss
        return None, {
            "url": url,
            "scraper": "bs4",
        }

    metadata = {
        "url": url,
        "word_count": count,
        "elapsed": time.time() - t1,
        "scraper": "bs4",
    }
    return text, metadata
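All three scrapers share the same calling convention: they take a URL plus a memoize flag and return (content, metadata), with content set to None on failure. A small sketch of driving them over a list of URLs, falling back from the newspaper-based scraper to the bs4 one (the URL is a placeholder):

    # Illustrative driver built on the functions defined above.
    from scrapers import newspaper_scraper, bs4_scraper

    urls = ['http://example.com/some-article']    # placeholder URL
    for url in urls:
        text, meta = newspaper_scraper(url, memoize=False)
        if text is None:
            text, meta = bs4_scraper(url, memoize=False)
        if text is not None:
            print(meta.get('word_count'), 'words scraped from', meta['url'])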
examples/language/gpt/tools/download/utils.py
0 → 100644
View file @
203ca57a
# Code taken in large part from https://github.com/jcpeterson/openwebtext

import collections.abc
import os
import os.path as op
import re
import tarfile


def extract_month(url_file_name):
    month_re = r"(RS_.*2\d{3}-\d{2})"
    month = op.split(url_file_name)[-1]
    month = re.match(month_re, month).group()
    return month


def chunks(l, n, s=0):
    """Yield successive n-sized chunks from l, skipping the first s elements."""
    # collections.abc.Iterable: the plain collections alias was removed in Python 3.10
    if isinstance(l, collections.abc.Iterable):
        chnk = []
        for i, elem in enumerate(l):
            if i < s:
                continue
            chnk.append(elem)
            if len(chnk) == n:
                yield chnk
                chnk = []
        if len(chnk) != 0:
            yield chnk
    else:
        for i in range(s, len(l), n):
            yield l[i:i + n]


def extract_archive(archive_fp, outdir="."):
    with tarfile.open(archive_fp, "r") as tar:
        tar.extractall(outdir)
    return outdir


def mkdir(fp):
    try:
        os.makedirs(fp)
    except FileExistsError:
        pass
    return fp


def linecount(filename):
    """Count newlines by reading the file in 1 MiB binary chunks."""
    lines = 0
    buf_size = 1024 * 1024
    with open(filename, 'rb') as f:
        read_f = f.raw.read
        buf = read_f(buf_size)
        while buf:
            lines += buf.count(b'\n')
            buf = read_f(buf_size)
    return lines
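A short usage sketch for the helpers above (the urls.txt path is only an example and must exist for the last call to succeed):

    # Illustrative usage of chunks(), mkdir() and linecount().
    from utils import chunks, linecount, mkdir

    print(list(chunks(list(range(10)), n=4)))           # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    print(list(chunks(range(10), n=4, s=1)))            # same, but the first element is skipped
    out_dir = mkdir('scraped')                          # creates ./scraped if it does not exist
    print(linecount('urls.txt'), 'lines in urls.txt')   # assumes urls.txt is present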
examples/language/gpt/train_gpt.py
0 → 100644
View file @
203ca57a
import contextlib
import os

import torch
from dataset.webtext import WebtextDataset
from titans.loss.lm_loss import GPTLMLoss

import colossalai
import colossalai.utils as utils
from colossalai import nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.trainer import Trainer, hooks
from colossalai.utils import is_using_pp
from colossalai.utils.timer import MultiTimer
from colossalai.zero.init_ctx import ZeroInitContext


def calc_local_model_size(model: torch.nn.Module):
    numel_per_device = 0
    for p in model.parameters():
        numel_per_device += p.numel()
    return numel_per_device


def main():
    parser = colossalai.get_default_parser()
    parser.add_argument('--from_torch', default=False, action='store_true')
    args = parser.parse_args()
    disable_existing_loggers()
    if args.from_torch:
        colossalai.launch_from_torch(config=args.config)
    else:
        colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
    logger = get_dist_logger()

    logger.info('Build data loader', ranks=[0])
    train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
    train_dataloader = utils.get_dataloader(train_ds,
                                            seed=42,
                                            batch_size=gpc.config.BATCH_SIZE,
                                            pin_memory=True,
                                            shuffle=True,
                                            drop_last=True)

    logger.info('Build model', ranks=[0])
    use_pipeline = is_using_pp()
    use_interleaved = hasattr(gpc.config.model, 'num_chunks')
    num_chunks = getattr(gpc.config.model, 'num_chunks', 1)
    use_zero3 = hasattr(gpc.config, 'zero')

    if not use_pipeline:
        ctx = contextlib.nullcontext()
        if use_zero3:
            ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
                                  shard_strategy=gpc.config.zero.model_config.shard_strategy,
                                  shard_param=True)
        with ctx:
            model = gpc.config.model.pop('type')(**gpc.config.model)
    else:
        pipelinable = PipelinableContext()
        with pipelinable:
            model = gpc.config.model.pop('type')(**gpc.config.model)

        def mask_function(attention_mask=None):
            if attention_mask is not None:
                batch_size = gpc.config.BATCH_SIZE // gpc.config.NUM_MICRO_BATCHES
                attention_mask = attention_mask.view(batch_size, -1)
                attention_mask = col_nn.partition_batch(attention_mask)
                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                attention_mask = (1.0 - attention_mask) * -10000.0
            return attention_mask

        # GPT2_small exec_seq
        # (lyl)TODO: The exec_seq for gpt3 will be added here and to_layer_list should be more friendly to use.
        exec_seq = [
            'embed', mask_function, 'blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'blocks.4', 'blocks.5',
            (mask_function, "front"), 'blocks.6', 'blocks.7', 'blocks.8', 'blocks.9', 'blocks.10', 'blocks.11',
            'norm', 'head'
        ]
        pipelinable.to_layer_list(exec_seq)

        ctx = contextlib.nullcontext()
        # (lyl)TODO: Zero context and pipelinable context should be integrated into one context.
        if use_zero3:
            ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
                                  shard_strategy=gpc.config.zero.model_config.shard_strategy,
                                  shard_param=True)
        with ctx:
            model = pipelinable.partition(num_chunks, gpc.pipeline_parallel_size,
                                          gpc.get_local_rank(ParallelMode.PIPELINE))

    if use_zero3:
        numel = ctx.model_numel_tensor.item()
    else:
        numel = calc_local_model_size(model)

    tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \
        * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024**4)

    criterion = getattr(gpc.config, 'loss_fn', None)
    if criterion is not None:
        criterion = criterion.type()
    else:
        criterion = GPTLMLoss()

    logger.info('Build optimizer', ranks=[0])
    optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer)

    lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)

    engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
                                                                      optimizer,
                                                                      criterion,
                                                                      train_dataloader=train_dataloader,
                                                                      lr_scheduler=lr_scheduler)

    global_batch_size = gpc.config.BATCH_SIZE * \
        gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
    logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])

    timer = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timer)

    hook_list = [
        hooks.LossHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
        hooks.LogMetricByEpochHook(logger),
        hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
        hooks.LogMetricByStepHook(),
        hooks.LogMemoryByEpochHook(logger),
    ]

    trainer.fit(train_dataloader=train_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                test_interval=1,
                hooks=hook_list,
                display_progress=True,
                return_output_label=False)


if __name__ == '__main__':
    main()
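train_gpt.py reads its hyperparameters from the Python config file passed via --config (BATCH_SIZE, SEQ_LEN, NUM_EPOCHS, the model / optimizer / loss_fn dicts, and optionally zero and NUM_MICRO_BATCHES) and takes the WebText data path from the DATA environment variable. A rough sketch of what a minimal non-pipeline config could look like is below; the gpt2_small import path and every value shown are assumptions for illustration, not part of this commit:

    # gpt2_small_config.py -- illustrative sketch only; the model import path is a
    # placeholder and all hyperparameters are example values.
    import torch
    from titans.loss.lm_loss import GPTLMLoss
    from titans.model.gpt import gpt2_small    # placeholder import; substitute the repo's actual GPT builder

    BATCH_SIZE = 8
    SEQ_LEN = 1024
    NUM_EPOCHS = 60
    NUM_MICRO_BATCHES = 4    # only consulted by the pipeline-parallel branch

    optimizer = dict(type=torch.optim.Adam, lr=0.00015, weight_decay=1e-2)
    loss_fn = dict(type=GPTLMLoss)
    model = dict(type=gpt2_small, checkpoint=True)    # kwargs other than 'type' are forwarded to the builder

With such a config the script could be launched along the lines of `DATA=/path/to/webtext torchrun --nproc_per_node=8 train_gpt.py --config gpt2_small_config.py --from_torch` (an illustrative command, not taken from this commit).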