OpenDAS / Megatron-LM · Commits

Commit 4558e42f, authored Apr 03, 2020 by Neel Kant

    Implement InverseClozeDataset with IndexedDataset

Parent: 9350ee08

Showing 3 changed files with 238 additions and 62 deletions
megatron/data/bert_dataset.py    +11   -4
megatron/data/ict_dataset.py     +162  -0
pretrain_bert_ict.py             +65   -58
megatron/data/bert_dataset.py  (view file @ 4558e42f)

...
@@ -27,13 +27,15 @@ from megatron import mpu
 from megatron.data import helpers
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
+from megatron.data.ict_dataset import InverseClozeDataset
 from megatron import print_rank_0


 def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
-                                    short_seq_prob, seed, skip_warmup):
+                                    short_seq_prob, seed, skip_warmup,
+                                    ict_dataset=False):

     # Indexed dataset.
     indexed_dataset = get_indexed_dataset_(data_prefix,
...
@@ -74,16 +76,21 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             # New doc_idx view.
             indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
             # Build the dataset accordingly.
-            dataset = BertDataset(
-                name=name,
-                indexed_dataset=indexed_dataset,
-                data_prefix=data_prefix,
-                num_epochs=None,
-                max_num_samples=train_valid_test_num_samples[index],
-                masked_lm_prob=masked_lm_prob,
-                max_seq_length=max_seq_length,
-                short_seq_prob=short_seq_prob,
-                seed=seed)
+            kwargs = dict(
+                name=name,
+                indexed_dataset=indexed_dataset,
+                data_prefix=data_prefix,
+                num_epochs=None,
+                max_num_samples=train_valid_test_num_samples[index],
+                max_seq_length=max_seq_length,
+                short_seq_prob=short_seq_prob,
+                seed=seed)
+            if ict_dataset:
+                dataset = InverseClozeDataset(**kwargs)
+            else:
+                dataset = BertDataset(masked_lm_prob=masked_lm_prob, **kwargs)
             # Set the original pointer so dataset remains the main dataset.
             indexed_dataset.set_doc_idx(doc_idx_ptr)
             # Checks.
...
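A minimal usage sketch of the new flag, assuming a corpus has already been preprocessed into an indexed dataset at the given prefix; the data prefix, split string, and sample counts below are placeholders invented for illustration, while the keyword names come from the signature above. With ict_dataset=True the same builder returns InverseClozeDataset splits instead of BertDataset splits.

    from megatron.data.bert_dataset import build_train_valid_test_datasets

    # Placeholder arguments; only the keyword names are taken from the diff above.
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix='my-corpus_text_sentence',   # hypothetical preprocessed prefix
        data_impl='mmap',
        splits_string='949,50,1',
        train_valid_test_num_samples=[1000, 100, 10],
        max_seq_length=512,
        masked_lm_prob=0.15,
        short_seq_prob=0.1,
        seed=1234,
        skip_warmup=True,
        ict_dataset=True)                        # new flag added in this commit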
megatron/data/ict_dataset.py  (new file, mode 100644, view file @ 4558e42f)

import random

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from .bert_dataset import get_samples_mapping_


class InverseClozeDataset(Dataset):
    """Dataset containing sentences and various 'blocks' for an inverse cloze task."""
    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length,
                 short_seq_prob, seed):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.short_seq_prob = short_seq_prob
        self.indexed_dataset = indexed_dataset

        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                    data_prefix,
                                                    num_epochs,
                                                    max_num_samples,
                                                    self.max_seq_length,
                                                    short_seq_prob,
                                                    self.seed,
                                                    self.name)

        # keep a handle on the tokenizer; sentence_tokenize() uses self.tokenizer
        tokenizer = self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        # get rng state corresponding to index (allows deterministic random pair)
        rng = random.Random(idx + 1000)
        np_rng = np.random.RandomState(
            seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])

        # get seq length. Save 2 tokens for beginning and end
        target_seq_length = self.max_seq_length - 2
        if rng.random() < self.short_seq_prob:
            target_seq_length = rng.randint(5, target_seq_length)

        input_data, context_data = self.get_input_and_context(
            target_seq_length, rng, np_rng)
        input_tokens, input_token_types, input_pad_mask = input_data
        context_tokens, context_token_types, context_pad_mask = context_data

        sample = {
            'input_text': np.array(input_tokens),
            'input_types': np.array(input_token_types),
            'input_pad_mask': np.array(input_pad_mask),
            'context_text': np.array(context_tokens),
            'context_types': np.array(context_token_types),
            'context_pad_mask': np.array(context_pad_mask)
        }

        return sample

    def get_sentence_split_doc(self, idx):
        """fetch document at index idx and split into sentences"""
        document = self.indexed_dataset[idx]
        if isinstance(document, dict):
            document = document['text']
        lines = document.split('\n')
        return [line for line in lines if line]

    def sentence_tokenize(self, sent, sentence_num=0):
        """tokenize sentence and get token types"""
        tokens = self.tokenizer.EncodeAsIds(sent).tokenization
        str_type = 'str' + str(sentence_num)
        token_types = [self.tokenizer.get_type(str_type).Id] * len(tokens)
        return tokens, token_types

    def concat_and_pad_tokens(self, tokens, token_types):
        """concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = [self.cls_id] + tokens + [self.sep_id]
        token_types = [token_types[0]] + token_types + [token_types[0]]
        assert len(tokens) <= self.max_seq_length

        num_pad = max(0, self.max_seq_length - len(tokens))
        pad_mask = [0] * len(tokens) + [1] * num_pad
        tokens += [self.pad_id] * num_pad
        token_types += [token_types[0]] * num_pad
        return tokens, token_types, pad_mask

    def get_input_and_context(self, target_seq_length, rng, np_rng):
        """fetches a sentence and its surrounding context"""
        num_tries = 0
        while num_tries < 20:
            num_tries += 1
            doc = None
            while doc is None:
                doc_idx = np_rng.randint(len(self) - 1)
                # doc is a list of sentences
                doc = self.get_sentence_split_doc(doc_idx)
                if not doc:
                    doc = None

            # set up and tokenize the entire selected document
            num_sentences = len(doc)
            padless_max_len = self.max_seq_length - 2

            # select a random sentence from the document as input
            # TODO: consider adding multiple input sentences.
            input_sentence_idx = rng.randint(0, num_sentences - 1)
            tokens, token_types = self.sentence_tokenize(doc[input_sentence_idx], 0)
            input_tokens, input_token_types = \
                tokens[:target_seq_length], token_types[:target_seq_length]
            if not len(input_tokens) > 0:
                continue

            context_tokens, context_token_types = [], []
            # 10% of the time, the input sentence is left in the context.
            # The other 90% of the time, keep it out.
            if rng.random() < 0.1:
                context_tokens = input_tokens.copy()
                context_token_types = input_token_types.copy()

            # parameters for examining sentences to add to the context
            view_preceding = True
            view_radius = 1
            while len(context_tokens) < padless_max_len:
                # keep adding sentences while the context can accommodate more.
                if view_preceding:
                    examine_idx = input_sentence_idx - view_radius
                    if examine_idx >= 0:
                        new_tokens, new_token_types = self.sentence_tokenize(
                            doc[examine_idx], 0)
                        context_tokens = new_tokens + context_tokens
                        context_token_types = new_token_types + context_token_types
                else:
                    examine_idx = input_sentence_idx + view_radius
                    if examine_idx < num_sentences:
                        new_tokens, new_token_types = self.sentence_tokenize(
                            doc[examine_idx], 0)
                        context_tokens += new_tokens
                        context_token_types += new_token_types
                    view_radius += 1
                view_preceding = not view_preceding
                if view_radius > num_sentences:
                    break

            # assemble the tokens and token types of the context
            context_tokens = context_tokens[:padless_max_len]
            context_token_types = context_token_types[:padless_max_len]
            if not len(context_tokens) > 0:
                continue

            # concatenate 'CLS' and 'SEP' tokens and add extra token types
            input_tokens, input_token_types, input_pad_mask = \
                self.concat_and_pad_tokens(input_tokens, input_token_types)
            context_tokens, context_token_types, context_pad_mask = \
                self.concat_and_pad_tokens(context_tokens, context_token_types)

            return (input_tokens, input_token_types, input_pad_mask), \
                   (context_tokens, context_token_types, context_pad_mask)
        else:
            raise RuntimeError("Could not get a valid data point from InverseClozeDataset")
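For intuition, here is a standalone sketch of the sampling scheme that get_input_and_context implements, stripped of tokenization, truncation, and padding: pick one sentence of a document as the query and use its neighbours as the evidence block, keeping the query inside the block 10% of the time. The helper name and the toy document below are invented for illustration and are not part of the commit.

    import random

    def inverse_cloze_pair(sentences, rng, keep_query_prob=0.1):
        """Pick one sentence as the query; the rest of the document forms the
        evidence block. The query stays in the block with probability
        keep_query_prob (0.1 in the dataset above)."""
        query_idx = rng.randint(0, len(sentences) - 1)
        query = sentences[query_idx]
        block = [s for i, s in enumerate(sentences)
                 if i != query_idx or rng.random() < keep_query_prob]
        return query, block

    rng = random.Random(1234)
    doc = ["Sentence one.", "Sentence two.", "Sentence three.", "Sentence four."]
    query, block = inverse_cloze_pair(doc, rng)
    print(query)
    print(block)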
pretrain_bert_ict.py  (view file @ 4558e42f)

...
@@ -18,43 +18,32 @@
 import torch
 import torch.nn.functional as F

-from configure_data import configure_data
+from megatron import get_args
+from megatron import get_timers
 from megatron import mpu
+from megatron import print_rank_0
+from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.model import ICTBertModel
-from megatron.utils import print_rank_0
+from megatron.training import pretrain
 from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses
-from megatron.utils import vocab_size_with_padding
-from megatron.training import run
-
-num_batches = 0


-def model_provider(args):
+def model_provider():
     """Build the model."""
+    args = get_args()

     print_rank_0('building BERT models ...')

-    model = ICTBertModel(
-        num_layers=args.num_layers,
-        vocab_size=args.vocab_size,
-        hidden_size=args.hidden_size,
-        num_attention_heads=args.num_attention_heads,
-        embedding_dropout_prob=args.hidden_dropout,
-        attention_dropout_prob=args.attention_dropout,
-        output_dropout_prob=args.hidden_dropout,
-        max_sequence_length=args.max_position_embeddings,
-        checkpoint_activations=args.checkpoint_activations,
-        ict_head_size=128,
-        checkpoint_num_layers=args.checkpoint_num_layers,
-        layernorm_epsilon=args.layernorm_epsilon,
-        num_tokentypes=args.tokentype_size,
-        parallel_output=True,
-        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
-        attention_softmax_in_fp32=args.attention_softmax_in_fp32)
+    model = ICTBertModel(
+        ict_head_size=128,
+        num_tokentypes=2,
+        parallel_output=True)

     return model


-def get_batch(data_iterator, timers):
+def get_batch(data_iterator):

     # Items and their type.
     keys = ['input_text', 'input_types', 'input_pad_mask',
...
@@ -62,13 +51,10 @@ def get_batch(data_iterator, timers):
     datatype = torch.int64

     # Broadcast data.
-    timers('data loader').start()
     if data_iterator is None:
         data = None
     else:
         data = next(data_iterator)
-    timers('data loader').stop()
     data_b = mpu.broadcast_data(keys, data, datatype)

     # Unpack.
...
@@ -83,17 +69,17 @@ def get_batch(data_iterator, timers):
            context_tokens, context_types, context_pad_mask


-def forward_step(data_iterator, model, args, timers):
+def forward_step(data_iterator, model):
     """Forward step."""
+    timers = get_timers()

     # Get the batch.
     timers('batch generator').start()
     input_tokens, input_types, input_pad_mask, \
-        context_tokens, context_types, context_pad_mask = get_batch(data_iterator, timers)
+        context_tokens, context_types, context_pad_mask = get_batch(data_iterator)
     timers('batch generator').stop()

     # Forward model.
     # TODO: important to make sure that everything, including padding mask is as expected here.
     retrieval_scores = model(input_tokens, input_pad_mask, input_types,
                              context_tokens, context_pad_mask, context_types).float()
...
@@ -112,50 +98,71 @@ def forward_step(data_iterator, model, args, timers):
            'top5_acc': reduced_losses[2]}


-def get_train_val_test_data(args):
+def get_train_val_test_data():
     """Load the data on rank zero and broadcast number of tokens to all GPUs."""
+    args = get_args()

     (train_data, val_data, test_data) = (None, None, None)

     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
-        if (args.data_loader == 'raw'
-                or args.data_loader == 'lazy'
-                or args.data_loader == 'tfrecords'):
-            data_config = configure_data()
-            ds_type = 'BERT_ict'
-            data_config.set_defaults(data_set_type=ds_type, transpose=False)
-            (train_data, val_data, test_data), tokenizer = data_config.apply(args)
-            num_tokens = vocab_size_with_padding(tokenizer.num_tokens, args)
-            # Need to broadcast num_tokens and num_type_tokens.
-            token_counts = torch.cuda.LongTensor([num_tokens,
-                                                  tokenizer.num_type_tokens,
-                                                  int(args.do_train),
-                                                  int(args.do_valid),
-                                                  int(args.do_test)])
-        else:
-            print("Unsupported data loader for BERT.")
-            exit(1)
+        print_rank_0('> building train, validation, and test datasets '
+                     'for BERT ...')
+
+        data_parallel_size = mpu.get_data_parallel_world_size()
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        global_batch_size = args.batch_size * data_parallel_size
+
+        # Number of train/valid/test samples.
+        train_iters = args.train_iters
+        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [train_iters * global_batch_size,
+                                      eval_iters * global_batch_size,
+                                      test_iters * global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0('    train: {}'.format(train_val_test_num_samples[0]))
+        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0('    test: {}'.format(train_val_test_num_samples[2]))
+
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+            data_prefix=args.data_path,
+            data_impl=args.data_impl,
+            splits_string=args.split,
+            train_valid_test_num_samples=train_val_test_num_samples,
+            max_seq_length=args.seq_length,
+            masked_lm_prob=args.mask_prob,
+            short_seq_prob=args.short_seq_prob,
+            seed=args.seed,
+            skip_warmup=(not args.mmap_warmup),
+            ict_dataset=True)
+        print_rank_0("> finished creating BERT ICT datasets ...")
+
+        train_data = make_data_loader(train_ds)
+        valid_data = make_data_loader(valid_ds)
+        test_data = make_data_loader(test_ds)
+
+        do_train = train_data is not None and args.train_iters > 0
+        do_valid = valid_data is not None and args.eval_iters > 0
+        do_test = test_data is not None and args.eval_iters > 0
+        # Need to broadcast num_tokens and num_type_tokens.
+        flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
     else:
-        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
+        flags = torch.cuda.LongTensor([0, 0, 0])

     # Broadcast num tokens.
-    torch.distributed.broadcast(token_counts,
+    torch.distributed.broadcast(flags,
                                 mpu.get_model_parallel_src_rank(),
                                 group=mpu.get_model_parallel_group())
-    num_tokens = token_counts[0].item()
-    num_type_tokens = token_counts[1].item()
-    args.do_train = token_counts[2].item()
-    args.do_valid = token_counts[3].item()
-    args.do_test = token_counts[4].item()
-    args.vocab_size = num_tokens
-    args.tokentype_size = num_type_tokens
+    args.do_train = flags[0].item()
+    args.do_valid = flags[1].item()
+    args.do_test = flags[2].item()

     return train_data, val_data, test_data


 if __name__ == "__main__":

-    run('Pretrain ICT BERT model', get_train_val_test_data,
-        model_provider, forward_step)
+    pretrain(get_train_val_test_data, model_provider, forward_step,
+             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
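The visible hunks only show that forward_step obtains retrieval_scores from the model and reports reduced_losses including a 'top5_acc' entry; the loss computation itself sits outside the excerpt. As an illustration only (not necessarily what this commit does), an in-batch-negatives retrieval loss over a [batch, batch] score matrix, where the correct block for query i is block i, could be computed like this; the helper name and the random stand-in scores are hypothetical.

    import torch
    import torch.nn.functional as F

    def retrieval_loss_and_acc(retrieval_scores):
        """Hypothetical sketch: cross-entropy over in-batch negatives plus a
        top-5 retrieval accuracy, given a [batch, batch] similarity matrix."""
        batch_size = retrieval_scores.size(0)
        labels = torch.arange(batch_size, device=retrieval_scores.device)
        loss = F.cross_entropy(retrieval_scores, labels)
        top5 = retrieval_scores.topk(min(5, batch_size), dim=1).indices
        top5_acc = (top5 == labels.unsqueeze(1)).any(dim=1).float().mean()
        return loss, top5_acc

    scores = torch.randn(8, 8)   # stand-in for the model's retrieval_scores
    loss, top5_acc = retrieval_loss_and_acc(scores)
    print(loss.item(), top5_acc.item())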