OpenDAS / Megatron-LM / Commits / a00b3c79

Commit a00b3c79, authored Jun 28, 2020 by Neel Kant

Bring in indexing code to what was previously ict-merge

Parents: 2a3b445d, 8f3f338a

Showing 8 changed files with 167 additions and 18 deletions (+167 -18)
.gitignore                               +11   -1
megatron/arguments.py                     +2   -0
megatron/data/realm_dataset.py            +5   -2
megatron/data/realm_dataset_utils.py     +48   -3
megatron/tokenizer/tokenizer.py           +5   -5
megatron/training.py                      +7   -0
pretrain_bert_ict.py                      +2   -2
tools/preprocess_data.py                 +87   -5
.gitignore

@@ -3,4 +3,14 @@ __pycache__
 # Distribution / packaging
 build/
 dist/
-*.egg-info/
\ No newline at end of file
+*.egg-info/
+
+# added by neel
+*.npy
+*.bin
+*.idx
+*.pkl
+raw_*
+run_*
+realm_*
megatron/arguments.py

@@ -386,6 +386,8 @@ def _add_data_args(parser):
                        help='Mask loss for the end of document tokens.')
     group.add_argument('--query-in-block-prob', type=float, default=0.1,
                        help='Probability of keeping query in block for ICT dataset')
+    group.add_argument('--faiss-use-gpu', action='store_true',
+                       help='Whether to create the FaissMIPSIndex on GPU')

     return parser
...
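The new --faiss-use-gpu option only registers the flag; the FaissMIPSIndex that consumes it is not part of this commit. A minimal sketch of what such a flag typically toggles with the FAISS API (the build_mips_index helper below is illustrative, not code from this repository):

import faiss
import numpy as np


def build_mips_index(block_embeds, use_gpu=False):
    """Illustrative only: shows what a flag like --faiss-use-gpu usually toggles
    when building a maximum-inner-product-search index over block embeddings."""
    dim = block_embeds.shape[1]
    index = faiss.IndexFlatIP(dim)           # exact inner-product search on CPU
    if use_gpu:
        res = faiss.StandardGpuResources()   # requires the faiss-gpu build
        index = faiss.index_cpu_to_gpu(res, 0, index)
    index.add(np.ascontiguousarray(block_embeds, dtype=np.float32))
    return index


# e.g. top-5 blocks per query embedding:
# scores, block_ids = build_mips_index(block_embeds, args.faiss_use_gpu).search(query_embeds, 5)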
megatron/data/realm_dataset.py

 import collections
 import itertools
 import random
...
@@ -39,7 +40,9 @@ class ICTDataset(Dataset):

     def __getitem__(self, idx):
         """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
-        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
+        sample_data = self.samples_mapping[idx]
+        start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()

         if self.use_titles:
             title = self.title_dataset[int(doc_idx)]
             title_pad_offset = 3 + len(title)
...
@@ -65,7 +68,7 @@ class ICTDataset(Dataset):
         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

-        block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
+        block_data = sample_data.as_array()

         sample = {
             'query_tokens': query_tokens,
...
megatron/data/realm_dataset_utils.py

@@ -5,6 +5,8 @@ import numpy as np
 import torch

 from megatron import mpu, print_rank_0
+from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
+from megatron import get_args, get_tokenizer, print_rank_0, mpu


 def join_str_list(str_list):
...
@@ -18,10 +20,47 @@ def join_str_list(str_list):
     return result


+class BlockSampleData(object):
+    """A struct for fully describing a fixed-size block of data as used in REALM.
+
+    :param start_idx: for first sentence of the block
+    :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
+    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
+    :param block_idx: a unique integer identifier given to every block.
+    """
+    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
+        self.start_idx = start_idx
+        self.end_idx = end_idx
+        self.doc_idx = doc_idx
+        self.block_idx = block_idx
+
+    def as_array(self):
+        return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
+
+    def as_tuple(self):
+        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
+
+
+class BlockSamplesMapping(object):
+    def __init__(self, mapping_array):
+        # make sure that the array is compatible with BlockSampleData
+        assert mapping_array.shape[1] == 4
+        self.mapping_array = mapping_array
+
+    def __getitem__(self, idx):
+        """Get the data associated with a particular sample."""
+        sample_data = BlockSampleData(*self.mapping_array[idx])
+        return sample_data
+
+
 def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                               max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
     """Get samples mapping for a dataset over fixed size blocks. This function also requires
-    a dataset of the titles for the source documents since their lengths must be taken into account."""
+    a dataset of the titles for the source documents since their lengths must be taken into account.
+
+    :return: samples_mapping (BlockSamplesMapping)
+    """

     if not num_epochs:
         if not max_num_samples:
             raise ValueError("Need to specify either max_num_samples "
...
@@ -58,19 +97,24 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
         start_time = time.time()
         print_rank_0(' > building samples index mapping for {} ...'.format(name))

         # compile/bind the C++ helper code
         from megatron.data.dataset_utils import compile_helper
         compile_helper()

         from megatron.data import helpers
-        samples_mapping = helpers.build_blocks_mapping(
+        mapping_array = helpers.build_blocks_mapping(
             block_dataset.doc_idx,
             block_dataset.sizes,
             title_dataset.sizes,
             num_epochs,
             max_num_samples,
             max_seq_length - 3,  # account for added tokens
             seed,
             verbose,
             use_one_sent_docs)
+
+        samples_mapping = BlockSamplesMapping(mapping_array)

         print_rank_0(' > done building samples index mapping')
         np.save(indexmap_filename, samples_mapping, allow_pickle=True)
         print_rank_0(' > saved the index mapping in {}'.format(
...
@@ -79,6 +123,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
         print_rank_0(' > elapsed time to build and save samples mapping '
                      '(seconds): {:4f}'.format(
                          time.time() - start_time))
+
     # This should be a barrier but nccl barrier assumes
     # device_index=rank which is not the case for model
     # parallel case
...
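For reference, a minimal usage sketch (not part of the diff) of how a row of the mapping array round-trips through the new helpers, assuming the (num_samples, 4) layout asserted in BlockSamplesMapping:

import numpy as np
from megatron.data.realm_dataset_utils import BlockSampleData, BlockSamplesMapping

# A made-up mapping with two blocks; columns are [start_idx, end_idx, doc_idx, block_idx].
mapping_array = np.array([[0, 4, 7, 100],
                          [4, 9, 7, 101]], dtype=np.int64)

samples_mapping = BlockSamplesMapping(mapping_array)   # asserts shape[1] == 4
sample = samples_mapping[1]                            # -> BlockSampleData
start_idx, end_idx, doc_idx, block_idx = sample.as_tuple()
assert (sample.as_array() == mapping_array[1]).all()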
megatron/tokenizer/tokenizer.py

@@ -31,11 +31,11 @@ def build_tokenizer(args):
     # Select and instantiate the tokenizer.
     assert args.vocab_file is not None
     if args.tokenizer_type == 'BertWordPieceLowerCase':
-        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                            lower_case=True)
+        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
+                                           lower_case=True)
     elif args.tokenizer_type == 'BertWordPieceCase':
-        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                            lower_case=False)
+        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
+                                           lower_case=False)
     elif args.tokenizer_type == 'GPT2BPETokenizer':
         assert args.merge_file is not None
         tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
...
@@ -124,7 +124,7 @@ class AbstractTokenizer(ABC):
                                   'tokenizer'.format(self.name))


-class _BertWordPieceTokenizer(AbstractTokenizer):
+class BertWordPieceTokenizer(AbstractTokenizer):
     """Original BERT wordpiece tokenizer."""

     def __init__(self, vocab_file, lower_case=True):
...
megatron/training.py

@@ -73,6 +73,11 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     args = get_args()
     timers = get_timers()

+    if args.rank == 0 and args.cased_data_path is not None:
+        import stanza
+        stanza.download('en', processors={'ner': 'conll03'}, dir='stanza')
+
     # Model, optimizer, and learning rate.
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
...
@@ -227,6 +232,8 @@ def setup_model_and_optimizer(model_provider_func):

 def backward_step(optimizer, model, loss):
     """Backward step."""
+    # if args.rank == 0:
+    #     torch.save(lick)
     args = get_args()
     timers = get_timers()
...
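The stanza.download call above only fetches the cased CoNLL-03 NER model; nothing in this commit runs it yet. A hedged sketch of how such a pipeline is typically driven (the exact Pipeline arguments may differ across stanza versions, and the input text is made up):

import stanza

# The download in pretrain() fetches the cased CoNLL-03 NER package:
#   stanza.download('en', processors={'ner': 'conll03'}, dir='stanza')
nlp = stanza.Pipeline(lang='en', processors={'tokenize': 'default', 'ner': 'conll03'},
                      dir='stanza', use_gpu=False)

doc = nlp("NVIDIA released Megatron-LM in 2019.")
for ent in doc.ents:
    print(ent.text, ent.type)   # CoNLL-03 tags are PER/ORG/LOC/MISC, e.g. NVIDIA -> ORG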
pretrain_bert_ict.py

@@ -30,7 +30,7 @@ from megatron.utils import reduce_losses
 num_batches = 0


-def general_model_provider(only_query_model=False, only_block_model=False):
+def general_ict_model_provider(only_query_model=False, only_block_model=False):
     """Build the model."""
     args = get_args()
     assert args.ict_head_size is not None, \
...
@@ -53,7 +53,7 @@ def general_model_provider(only_query_model=False, only_block_model=False):

 def model_provider():
-    return general_model_provider(False, False)
+    return general_ict_model_provider(False, False)


 def get_batch(data_iterator):
...
tools/preprocess_data.py

@@ -24,6 +24,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))
 import time

+import numpy as np
 import torch
 try:
     import nltk
...
@@ -31,8 +32,11 @@ try:
 except ImportError:
     nltk_available = False

 from megatron.tokenizer import build_tokenizer
 from megatron.data import indexed_dataset
+from megatron.data.realm_dataset_utils import id_to_str_pos_map


 # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
...
@@ -75,6 +79,14 @@ class Encoder(object):
         else:
             Encoder.splitter = IdentitySplitter()

+        try:
+            import spacy
+            print("> Loading spacy")
+            Encoder.spacy = spacy.load('en_core_web_lg')
+            print(">> Finished loading spacy")
+        except:
+            Encoder.spacy = None
+
     def encode(self, json_line):
         data = json.loads(json_line)
         ids = {}
...
@@ -90,6 +102,56 @@ class Encoder(object):
             ids[key] = doc_ids
         return ids, len(json_line)

+    def encode_with_ner(self, json_line):
+        if self.spacy is None:
+            raise ValueError('Cannot do NER without spacy')
+
+        data = json.loads(json_line)
+        ids = {}
+        ner_masks = {}
+        for key in self.args.json_keys:
+            text = data[key]
+            doc_ids = []
+            doc_ner_mask = []
+            for sentence in Encoder.splitter.tokenize(text):
+                sentence_ids = Encoder.tokenizer.tokenize(sentence)
+                if len(sentence_ids) > 0:
+                    doc_ids.append(sentence_ids)
+
+                    # sentence is cased?
+                    # print(sentence)
+                    entities = self.spacy(sentence).ents
+                    undesired_types = ['CARDINAL', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL']
+                    entities = [e for e in entities if e.text != "CLS" and e.label_ not in undesired_types]
+                    # entities = []
+
+                    masked_positions = []
+                    if len(entities) > 0:
+                        entity_idx = np.random.randint(0, len(entities))
+                        selected_entity = entities[entity_idx]
+
+                        token_pos_map = id_to_str_pos_map(sentence_ids, Encoder.tokenizer)
+                        mask_start = mask_end = 0
+                        set_mask_start = False
+                        while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
+                            if token_pos_map[mask_start] > selected_entity.start_char:
+                                set_mask_start = True
+                            if not set_mask_start:
+                                mask_start += 1
+                            mask_end += 1
+
+                        masked_positions = list(range(mask_start - 1, mask_end))
+
+                    ner_mask = [0] * len(sentence_ids)
+                    for pos in masked_positions:
+                        ner_mask[pos] = 1
+                    doc_ner_mask.append(ner_mask)
+
+            if self.args.append_eod:
+                doc_ids[-1].append(Encoder.tokenizer.eod)
+                doc_ner_mask[-1].append(0)
+
+            ids[key] = doc_ids
+            ner_masks[key + '-ner'] = doc_ner_mask
+        return ids, ner_masks, len(json_line)
+

 def get_args():
     parser = argparse.ArgumentParser()
     group = parser.add_argument_group(title='input data')
...
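To make the index arithmetic in encode_with_ner concrete, here is a small self-contained walk-through with made-up positions; the behaviour of id_to_str_pos_map is an assumption, since its implementation is not shown in this diff:

# Hypothetical walk-through of the masking loop above with made-up numbers.
# Assumption: id_to_str_pos_map (defined in megatron/data/realm_dataset_utils.py,
# not shown here) returns one character offset per wordpiece, marking where that
# wordpiece starts in the detokenized sentence.
token_pos_map = [0, 3, 8, 12, 15]      # five wordpieces starting at these characters


class FakeEntity:                      # spaCy entity spans expose start_char/end_char
    start_char, end_char = 8, 15


selected_entity = FakeEntity()

# Same loop as encode_with_ner above.
mask_start = mask_end = 0
set_mask_start = False
while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
    if token_pos_map[mask_start] > selected_entity.start_char:
        set_mask_start = True
    if not set_mask_start:
        mask_start += 1
    mask_end += 1

masked_positions = list(range(mask_start - 1, mask_end))
ner_mask = [0] * len(token_pos_map)
for pos in masked_positions:
    ner_mask[pos] = 1

print(masked_positions)   # [2, 3]: the wordpieces overlapping characters 8..15
print(ner_mask)           # [0, 0, 1, 1, 0]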
@@ -126,6 +188,8 @@ def get_args():
                        help='Number of worker processes to launch')
     group.add_argument('--log-interval', type=int, default=100,
                        help='Interval between progress updates')
+    group.add_argument('--create-ner-masks', action='store_true',
+                       help='Also create mask tensors for salient span masking')
     args = parser.parse_args()
     args.keep_empty = False
...
@@ -153,8 +217,11 @@ def main():
     encoder = Encoder(args)
     tokenizer = build_tokenizer(args)
     pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
-    encoded_docs = pool.imap(encoder.encode, fin, 25)
-    #encoded_docs = map(encoder.encode, fin)
+    if args.create_ner_masks:
+        encoded_docs = pool.imap(encoder.encode_with_ner, fin, 25)
+    else:
+        encoded_docs = pool.imap(encoder.encode, fin, 25)
+        #encoded_docs = map(encoder.encode, fin)

     level = "document"
     if args.split_sentences:
...
@@ -165,7 +232,10 @@ def main():
     output_bin_files = {}
     output_idx_files = {}
     builders = {}
-    for key in args.json_keys:
+    output_keys = args.json_keys.copy()
+    if args.create_ner_masks:
+        output_keys.extend([key + '-ner' for key in output_keys])
+    for key in output_keys:
         output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
                                                       key, level)
         output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
...
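With --create-ner-masks set, every JSON key gets a parallel '<key>-ner' output dataset alongside the token dataset. A quick sketch of the resulting file names, using made-up values for the output prefix and keys:

# Example values: 'my-corpus' and ['text'] are made up for illustration.
output_prefix, level = 'my-corpus', 'document'
output_keys = ['text']
output_keys.extend([key + '-ner' for key in output_keys])   # the comprehension is built before extend mutates the list
print(output_keys)   # ['text', 'text-ner']
print(["{}_{}_{}.bin".format(output_prefix, key, level) for key in output_keys])
# ['my-corpus_text_document.bin', 'my-corpus_text-ner_document.bin']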
@@ -179,12 +249,24 @@ def main():
     total_bytes_processed = 0
     print("Time to startup:", startup_end - startup_start)

-    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
+    # for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
+    for i, doc_data in enumerate(encoded_docs, start=1):
+        if args.create_ner_masks:
+            doc, ner_masks, bytes_processed = doc_data
+        else:
+            doc, bytes_processed = doc_data
         total_bytes_processed += bytes_processed
         for key, sentences in doc.items():
             for sentence in sentences:
                 builders[key].add_item(torch.IntTensor(sentence))
             builders[key].end_document()
+        if args.create_ner_masks:
+            for key, sentence_masks in ner_masks.items():
+                for mask in sentence_masks:
+                    builders[key].add_item(torch.IntTensor(mask))
+                builders[key].end_document()
         if i % args.log_interval == 0:
             current = time.time()
             elapsed = current - proc_start
...
@@ -193,7 +275,7 @@ def main():
             f"({i/elapsed} docs/s, {mbs} MB/s).",
             file=sys.stderr)

-    for key in args.json_keys:
+    for key in output_keys:
         builders[key].finalize(output_idx_files[key])


 if __name__ == '__main__':
...