OpenDAS / Megatron-LM

Commit 6495927e, authored Jun 22, 2020 by Neel Kant
Parent: fcc500d6

    Address most of comments from code review w/ Mohammad

Showing 13 changed files with 264 additions and 189 deletions.
    megatron/arguments.py                   +2   -2
    megatron/checkpointing.py               +2   -5
    megatron/data/bert_dataset.py           +75  -9
    megatron/data/dataset_utils.py          +7   -67
    megatron/data/helpers.cpp               +7   -1
    megatron/data/realm_dataset.py          +18  -23
    megatron/data/realm_dataset_utils.py    +3   -1
    megatron/model/bert_model.py            +17  -71
    megatron/model/realm_model.py           +72  -5
    megatron/model/utils.py                 +39  -0
    megatron/tokenizer/tokenizer.py         +8   -0
    megatron/training.py                    +3   -2
    pretrain_bert_ict.py                    +11  -3
megatron/arguments.py  (+2, -2)

@@ -136,6 +136,8 @@ def _add_network_size_args(parser):
                        ' grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
     group.add_argument('--hidden-size', type=int, default=None,
                        help='Transformer hidden size.')
+    group.add_argument('--ict-head-size', type=int, default=None,
+                       help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
     group.add_argument('--num-attention-heads', type=int, default=None,
                        help='Number of transformer attention heads.')
     group.add_argument('--max-position-embeddings', type=int, default=None,

@@ -202,8 +204,6 @@ def _add_training_args(parser):
 def _add_initialization_args(parser):
     group = parser.add_argument_group(title='initialization')

-    group.add_argument('--debug', action='store_true',
-                       help='Run things in debug mode')
     group.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy, '
                        'pytorch, and cuda.')
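The new --ict-head-size flag defaults to None, so ICT/REALM behavior stays opt-in; scripts that need it are expected to check the flag before building an ICT model (pretrain_bert_ict.py below does exactly that). A minimal, self-contained argparse sketch of that behavior, using made-up values rather than Megatron's real argument plumbing:

    # Sketch only: how an optional head-size flag like --ict-head-size behaves
    # when left unset vs. set explicitly.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings used in ICT and REALM '
                             '(paper default: 128)')

    args = parser.parse_args(['--ict-head-size', '128'])
    assert args.ict_head_size == 128

    args = parser.parse_args([])
    assert args.ict_head_size is None   # callers must check before building an ICT model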
megatron/checkpointing.py  (+2, -5)

@@ -128,13 +128,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     torch.distributed.barrier()


-def load_checkpoint(model, optimizer, lr_scheduler):
+def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
     """Load a model checkpoint and return the iteration."""
     args = get_args()
-    load_dir = args.load
-
-    from megatron.model.bert_model import BertModel
-    if isinstance(model, BertModel) and args.bert_load is not None:
-        load_dir = args.bert_load
+    load_dir = getattr(args, load_arg)

     if isinstance(model, torchDDP):
         model = model.module
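With load_arg, the loader reads its directory from getattr(args, load_arg) instead of hard-coding the BertModel / args.bert_load special case. A hedged sketch of the intended call pattern; it assumes megatron is importable and initialized and that model, optimizer and lr_scheduler have already been built:

    # Sketch only: the attribute named by load_arg is looked up on args via getattr,
    # so ICT/REALM code can reuse the generic loader with the --bert-load path.
    from megatron.checkpointing import load_checkpoint

    def load_main_and_bert(model, optimizer, lr_scheduler):
        iteration = load_checkpoint(model, optimizer, lr_scheduler)   # reads getattr(args, 'load')
        # Same loader, different source directory:
        load_checkpoint(model, optimizer, lr_scheduler, load_arg='bert_load')
        return iteration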
megatron/data/bert_dataset.py  (+75, -9)

@@ -25,6 +25,11 @@ from torch.utils.data import Dataset
 from megatron import get_tokenizer, get_args
 from megatron import mpu
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
+from megatron.data.dataset_utils import get_a_and_b_segments
+from megatron.data.dataset_utils import truncate_segments
+from megatron.data.dataset_utils import create_tokens_and_tokentypes
+from megatron.data.dataset_utils import pad_and_convert_to_numpy
+from megatron.data.dataset_utils import create_masked_lm_predictions
 from megatron import print_rank_0

@@ -61,8 +66,6 @@ class BertDataset(Dataset):
         self.sep_id = tokenizer.sep
         self.mask_id = tokenizer.mask
         self.pad_id = tokenizer.pad

-        from megatron.data.dataset_utils import build_training_sample
-        self.build_sample_fn = build_training_sample

     def __len__(self):
         return self.samples_mapping.shape[0]

@@ -73,13 +76,13 @@ class BertDataset(Dataset):
         # Note that this rng state should be numpy and not python since
         # python randint is inclusive whereas the numpy one is exclusive.
         np_rng = np.random.RandomState(seed=(self.seed + idx))
-        return self.build_sample_fn(sample, seq_length,
-                                    self.max_seq_length,  # needed for padding
-                                    self.vocab_id_list,
-                                    self.vocab_id_to_token_dict,
-                                    self.cls_id, self.sep_id,
-                                    self.mask_id, self.pad_id,
-                                    self.masked_lm_prob, np_rng)
+        return build_training_sample(sample, seq_length,
+                                     self.max_seq_length,  # needed for padding
+                                     self.vocab_id_list,
+                                     self.vocab_id_to_token_dict,
+                                     self.cls_id, self.sep_id,
+                                     self.mask_id, self.pad_id,
+                                     self.masked_lm_prob, np_rng)


 def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

@@ -214,3 +217,66 @@ def get_samples_mapping_(indexed_dataset,
             samples_mapping.shape[0]))

     return samples_mapping
+
+
+def build_training_sample(sample,
+                          target_seq_length, max_seq_length,
+                          vocab_id_list, vocab_id_to_token_dict,
+                          cls_id, sep_id, mask_id, pad_id,
+                          masked_lm_prob, np_rng):
+    """Build training sample.
+
+    Arguments:
+        sample: A list of sentences in which each sentence is a list of token ids.
+        target_seq_length: Desired sequence length.
+        max_seq_length: Maximum length of the sequence. All values are padded to
+            this length.
+        vocab_id_list: List of vocabulary ids. Used to pick a random id.
+        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
+        cls_id: Start of example id.
+        sep_id: Separator id.
+        mask_id: Mask token id.
+        pad_id: Padding token id.
+        masked_lm_prob: Probability to mask tokens.
+        np_rng: Random number generator. Note that this rng state should be
+            numpy and not python since python randint is inclusive for
+            the upper bound whereas the numpy one is exclusive.
+    """
+
+    # We assume that we have at least two sentences in the sample.
+    assert len(sample) > 1
+    assert target_seq_length <= max_seq_length
+
+    # Divide sample into two segments (A and B).
+    tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
+
+    # Truncate to `target_sequence_length`.
+    max_num_tokens = target_seq_length
+    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
+                                  len(tokens_b), max_num_tokens, np_rng)
+
+    # Build tokens and tokentypes.
+    tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
+                                                      cls_id, sep_id)
+
+    # Masking.
+    max_predictions_per_seq = masked_lm_prob * max_num_tokens
+    (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
+        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
+        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
+
+    # Padding.
+    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
+        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+                                   masked_labels, pad_id, max_seq_length)
+
+    train_sample = {
+        'text': tokens_np,
+        'types': tokentypes_np,
+        'labels': labels_np,
+        'is_random': int(is_next_random),
+        'loss_mask': loss_mask_np,
+        'padding_mask': padding_mask_np,
+        'truncated': int(truncated)}
+    return train_sample
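The docstring's point about the rng is the reason build_training_sample takes a numpy RandomState rather than Python's random module: the two randint conventions differ on the upper bound. A small self-contained check:

    # Illustration of the inclusive/exclusive difference noted in the docstring:
    # random.randint(a, b) can return b, while np.random.RandomState.randint(a, b) cannot.
    import random
    import numpy as np

    py_rng = random.Random(1234)
    np_rng = np.random.RandomState(seed=1234)

    assert max(py_rng.randint(0, 2) for _ in range(1000)) == 2   # upper bound included
    assert max(np_rng.randint(0, 2) for _ in range(1000)) == 1   # upper bound excluded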
megatron/data/dataset_utils.py  (+7, -67)

@@ -23,9 +23,11 @@ import itertools
 import numpy as np

 from megatron import print_rank_0, get_args
 from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_

-DATASET_TYPES = ['standard_bert', 'ict']
+DSET_TYPE_STD = 'standard_bert'
+DSET_TYPE_ICT = 'ict'
+DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]


 def compile_helper():
     """Compile helper function at runtime. Make sure this

@@ -40,68 +42,6 @@ def compile_helper():
         sys.exit(1)


-def build_training_sample(sample,
-                          target_seq_length, max_seq_length,
-                          vocab_id_list, vocab_id_to_token_dict,
-                          cls_id, sep_id, mask_id, pad_id,
-                          masked_lm_prob, np_rng):
-    ...
 [the removed body is identical to the build_training_sample now added to
  megatron/data/bert_dataset.py above]


 def get_a_and_b_segments(sample, np_rng):
     """Divide sample into a and b segments."""

@@ -418,7 +358,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     short_seq_prob, seed, skip_warmup,
                                     dataset_type='standard_bert'):

-    if dataset_type not in DATASET_TYPES:
+    if dataset_type not in DSET_TYPES:
         raise ValueError("Invalid dataset_type: ", dataset_type)

     # Indexed dataset.

@@ -426,7 +366,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                           data_impl,
                                           skip_warmup)

-    if dataset_type in ['ict']:
+    if dataset_type == DSET_TYPE_ICT:
         args = get_args()
         title_dataset = get_indexed_dataset_(args.titles_data_path,
                                              data_impl,

@@ -479,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                 seed=seed)

-            if dataset_type == 'ict':
+            if dataset_type == DSET_TYPE_ICT:
                 args = get_args()
                 dataset = ICTDataset(
                     block_dataset=indexed_dataset,
megatron/data/helpers.cpp  (+7, -1)

@@ -452,10 +452,12 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     // Current map index.
     uint64_t map_index = 0;

+    int32_t block_id = 0;
+
     // For each epoch:
     for (int32_t epoch=0; epoch < num_epochs; ++epoch) {
         // assign every block a unique id
-        int32_t block_id = 0;
         if (map_index >= max_num_samples) {
            if (verbose && (!second)) {
               cout << " reached " << max_num_samples << " samples after "

@@ -516,6 +518,10 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                 // Populate the map.
                 if (second) {
                     const auto map_index_0 = 4 * map_index;
+                    // Each sample has 4 items: the starting sentence index, ending sentence index,
+                    // the index of the document from which the block comes (used for fetching titles)
+                    // and the unique id of the block (used for creating block indexes)
                     maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                     maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                     maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
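On the Python side, each 4-item row produced by this mapping is unpacked as start index, end index, document index, and block id (see ICTDataset.__getitem__ in megatron/data/realm_dataset.py below). A toy illustration with a hand-built numpy array standing in for the real mapping:

    # Toy stand-in for the samples mapping produced by the C++ helper: each row is
    # [start sentence index, end sentence index (exclusive, as used by range()),
    #  document index, block id]. Values here are invented for illustration.
    import numpy as np

    samples_mapping = np.array([[0, 3, 7, 0],
                                [3, 5, 7, 1]], dtype=np.int64)

    start_idx, end_idx, doc_idx, block_idx = samples_mapping[1]
    print(start_idx, end_idx, doc_idx, block_idx)   # 3 5 7 1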
megatron/data/realm_dataset.py  (+18, -23)

@@ -41,14 +41,15 @@ class ICTDataset(Dataset):
         """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
         start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
         if self.use_titles:
-            title = list(self.title_dataset[int(doc_idx)])
+            title = self.title_dataset[int(doc_idx)]
             title_pad_offset = 3 + len(title)
         else:
             title = None
             title_pad_offset = 2
-        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
+        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
         assert len(block) > 1

+        # randint() is inclusive for Python rng
         rand_sent_idx = self.rng.randint(0, len(block) - 1)

         # keep the query in the context query_in_block_prob fraction of the time.

@@ -64,53 +65,47 @@ class ICTDataset(Dataset):
         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)

         sample = {
-            'query_tokens': np.array(query_tokens),
-            'query_pad_mask': np.array(query_pad_mask),
-            'block_tokens': np.array(block_tokens),
-            'block_pad_mask': np.array(block_pad_mask),
-            'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
+            'query_tokens': query_tokens,
+            'query_pad_mask': query_pad_mask,
+            'block_tokens': block_tokens,
+            'block_pad_mask': block_pad_mask,
+            'block_data': block_data,
         }

         return sample

     def encode_text(self, text):
         return self.tokenizer.tokenize(text)

-    def decode_tokens(self, token_ids):
-        """Utility function to help with debugging mostly"""
-        tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
-        exclude_list = ['[PAD]', '[CLS]']
-        non_pads = [t for t in tokens if t not in exclude_list]
-        joined_strs = join_str_list(non_pads)

     def get_block(self, start_idx, end_idx, doc_idx):
         """Get the IDs for an evidence block plus the title of the corresponding document"""
-        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
-        title = list(self.title_dataset[int(doc_idx)])
+        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
+        title = self.title_dataset[int(doc_idx)]

         block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

-        return (block_tokens, block_pad_mask)
+        return block_tokens, block_pad_mask

     def get_null_block(self):
         """Get empty block and title - used in REALM pretraining"""
         block, title = [], []

         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

-        return (block_tokens, block_pad_mask)
+        return block_tokens, block_pad_mask

     def concat_and_pad_tokens(self, tokens, title=None):
         """Concat with special tokens and pad sequence to self.max_seq_length"""
+        tokens = list(tokens)
         if title is None:
             tokens = [self.cls_id] + tokens + [self.sep_id]
         else:
+            title = list(title)
             tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
-        assert len(tokens) <= self.max_seq_length, len(tokens)
+        assert len(tokens) <= self.max_seq_length

         num_pad = self.max_seq_length - len(tokens)
         pad_mask = [1] * len(tokens) + [0] * num_pad
         tokens += [self.pad_id] * num_pad

-        return tokens, pad_mask
+        return np.array(tokens), np.array(pad_mask)
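Since concat_and_pad_tokens now returns numpy arrays and copies its inputs with list() internally, __getitem__ no longer wraps each field in np.array. A standalone sketch of the method's contract, with made-up token ids and a small max_seq_length rather than the class's real configuration:

    # Standalone sketch of the [CLS] title [SEP] block [SEP] + padding layout used by
    # ICTDataset.concat_and_pad_tokens; cls_id/sep_id/pad_id/max_seq_length are invented.
    import numpy as np

    def concat_and_pad(tokens, title=None, cls_id=101, sep_id=102, pad_id=0, max_seq_length=16):
        tokens = list(tokens)
        if title is None:
            tokens = [cls_id] + tokens + [sep_id]
        else:
            tokens = [cls_id] + list(title) + [sep_id] + tokens + [sep_id]
        assert len(tokens) <= max_seq_length

        num_pad = max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [pad_id] * num_pad
        return np.array(tokens), np.array(pad_mask)

    toks, mask = concat_and_pad([5, 6, 7], title=[9])
    print(toks)   # [101   9 102   5   6   7 102   0 ...]
    print(mask)   # [1 1 1 1 1 1 1 0 ...]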
megatron/data/realm_dataset_utils.py  (+3, -1)

@@ -20,6 +20,8 @@ def join_str_list(str_list):
 def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                               max_num_samples, max_seq_length, seed, name):
+    """Get samples mapping for a dataset over fixed size blocks. This function also requires
+    a dataset of the titles for the source documents since their lengths must be taken into account."""
     if not num_epochs:
         if not max_num_samples:
             raise ValueError("Need to specify either max_num_samples "

@@ -40,7 +42,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
     indexmap_filename += '.npy'

     # Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0 and \
+    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
         print(' > WARNING: could not find index map file {}, building '
               'the indices on rank 0 ...'.format(indexmap_filename))
megatron/model/bert_model.py  (+17, -71)

@@ -25,46 +25,12 @@ from megatron.model.utils import openai_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
+from megatron.model.utils import bert_attention_mask_func
+from megatron.model.utils import bert_extended_attention_mask
+from megatron.model.utils import bert_position_ids
 from megatron.module import MegatronModule


-def bert_attention_mask_func(attention_scores, attention_mask):
-    ...
-
-def bert_extended_attention_mask(attention_mask, dtype):
-    ...
-
-def bert_position_ids(token_ids):
-    ...
 [the removed bodies are identical to the definitions added to megatron/model/utils.py below]


 class BertLMHead(MegatronModule):
     """Masked LM head for Bert

@@ -110,40 +76,31 @@ class BertModel(MegatronModule):
     """Bert Language model."""

     def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 ict_head_size=None, parallel_output=True):
+                 parallel_output=True):
         super(BertModel, self).__init__()
         args = get_args()

         self.add_binary_head = add_binary_head
-        self.ict_head_size = ict_head_size
-        self.add_ict_head = ict_head_size is not None
-        assert not (self.add_binary_head and self.add_ict_head)
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)
-        add_pooler = self.add_binary_head or self.add_ict_head
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)

         self.language_model, self._language_model_key = get_language_model(
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
-            add_pooler=add_pooler,
+            add_pooler=self.add_binary_head,
             init_method=init_method,
             scaled_init_method=scaled_init_method)

-        if not self.add_ict_head:
-            self.lm_head = BertLMHead(
-                self.language_model.embedding.word_embeddings.weight.size(0),
-                args.hidden_size, init_method, args.layernorm_epsilon,
-                parallel_output)
-            self._lm_head_key = 'lm_head'
+        self.lm_head = BertLMHead(
+            self.language_model.embedding.word_embeddings.weight.size(0),
+            args.hidden_size, init_method, args.layernorm_epsilon,
+            parallel_output)
+        self._lm_head_key = 'lm_head'

         if self.add_binary_head:
             self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                 init_method)
             self._binary_head_key = 'binary_head'
-        elif self.add_ict_head:
-            self.ict_head = get_linear_layer(args.hidden_size, ict_head_size,
-                                             init_method)
-            self._ict_head_key = 'ict_head'

     def forward(self, input_ids, attention_mask, tokentype_ids=None):

@@ -151,7 +108,7 @@ class BertModel(MegatronModule):
             attention_mask, next(self.language_model.parameters()).dtype)
         position_ids = bert_position_ids(input_ids)

-        if self.add_binary_head or self.add_ict_head:
+        if self.add_binary_head:
             lm_output, pooled_output = self.language_model(
                 input_ids,
                 position_ids,

@@ -165,12 +122,9 @@ class BertModel(MegatronModule):
                 tokentype_ids=tokentype_ids)

         # Output.
-        if self.add_ict_head:
-            ict_logits = self.ict_head(pooled_output)
-            return ict_logits, None
         lm_logits = self.lm_head(
             lm_output, self.language_model.embedding.word_embeddings.weight)

         if self.add_binary_head:
             binary_logits = self.binary_head(pooled_output)
             return lm_logits, binary_logits

@@ -185,17 +139,13 @@ class BertModel(MegatronModule):
         state_dict_ = {}
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if not self.add_ict_head:
-            state_dict_[self._lm_head_key] \
-                = self.lm_head.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+        state_dict_[self._lm_head_key] \
+            = self.lm_head.state_dict_for_save_checkpoint(
+                destination, prefix, keep_vars)
         if self.add_binary_head:
             state_dict_[self._binary_head_key] \
                 = self.binary_head.state_dict(destination, prefix, keep_vars)
-        elif self.add_ict_head:
-            state_dict_[self._ict_head_key] \
-                = self.ict_head.state_dict(destination, prefix, keep_vars)

         return state_dict_

     def load_state_dict(self, state_dict, strict=True):

@@ -203,14 +153,10 @@ class BertModel(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if not self.add_ict_head:
-            self.lm_head.load_state_dict(
-                state_dict[self._lm_head_key], strict=strict)
+        self.lm_head.load_state_dict(
+            state_dict[self._lm_head_key], strict=strict)
         if self.add_binary_head:
             self.binary_head.load_state_dict(
                 state_dict[self._binary_head_key], strict=strict)
-        elif self.add_ict_head:
-            self.ict_head.load_state_dict(
-                state_dict[self._ict_head_key], strict=strict)
megatron/model/realm_model.py  (+72, -5)

@@ -6,6 +6,13 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
 from megatron.model import BertModel
 from megatron.module import MegatronModule
 from megatron import mpu
+from megatron.model.utils import get_linear_layer
+from megatron.model.utils import init_method_normal
+from megatron.model.language_model import get_language_model
+from megatron.model.utils import scaled_init_method_normal
+from megatron.model.utils import bert_attention_mask_func
+from megatron.model.utils import bert_extended_attention_mask
+from megatron.model.utils import bert_position_ids


 class ICTBertModel(MegatronModule):

@@ -17,10 +24,9 @@ class ICTBertModel(MegatronModule):
                  only_query_model=False,
                  only_block_model=False):
         super(ICTBertModel, self).__init__()
-        bert_args = dict(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=False,
+        bert_kwargs = dict(
+            ict_head_size=ict_head_size,
+            num_tokentypes=num_tokentypes,
             parallel_output=parallel_output
         )
         assert not (only_block_model and only_query_model)

@@ -29,12 +35,12 @@ class ICTBertModel(MegatronModule):
         if self.use_query_model:
             # this model embeds (pseudo-)queries - Embed_input in the paper
-            self.query_model = BertModel(**bert_args)
+            self.query_model = IREncoderBertModel(**bert_kwargs)
             self._query_key = 'question_model'

         if self.use_block_model:
             # this model embeds evidence blocks - Embed_doc in the paper
-            self.block_model = BertModel(**bert_args)
+            self.block_model = IREncoderBertModel(**bert_kwargs)
             self._block_key = 'context_model'

     def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):

@@ -116,3 +122,64 @@ class ICTBertModel(MegatronModule):
         # give each model the same ict_head to begin with as well
         query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
         self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
+
+
+class IREncoderBertModel(MegatronModule):
+    """Bert Language model."""
+
+    def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
+        super(IREncoderBertModel, self).__init__()
+        args = get_args()
+
+        self.ict_head_size = ict_head_size
+        self.parallel_output = parallel_output
+        init_method = init_method_normal(args.init_method_std)
+        scaled_init_method = scaled_init_method_normal(args.init_method_std,
+                                                       args.num_layers)
+
+        self.language_model, self._language_model_key = get_language_model(
+            attention_mask_func=bert_attention_mask_func,
+            num_tokentypes=num_tokentypes,
+            add_pooler=True,
+            init_method=init_method,
+            scaled_init_method=scaled_init_method)
+
+        self.ict_head = get_linear_layer(args.hidden_size, ict_head_size,
+                                         init_method)
+        self._ict_head_key = 'ict_head'
+
+    def forward(self, input_ids, attention_mask, tokentype_ids=None):
+        extended_attention_mask = bert_extended_attention_mask(
+            attention_mask, next(self.language_model.parameters()).dtype)
+        position_ids = bert_position_ids(input_ids)
+
+        lm_output, pooled_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids)
+
+        # Output.
+        if self.add_ict_head:
+            ict_logits = self.ict_head(pooled_output)
+            return ict_logits, None
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+                                       keep_vars=False):
+        """For easy load when model is combined with other heads,
+        add an extra key."""
+
+        state_dict_ = {}
+        state_dict_[self._language_model_key] \
+            = self.language_model.state_dict_for_save_checkpoint(
+                destination, prefix, keep_vars)
+        state_dict_[self._ict_head_key] \
+            = self.ict_head.state_dict(destination, prefix, keep_vars)
+        return state_dict_
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Customized load."""
+        self.language_model.load_state_dict(
+            state_dict[self._language_model_key], strict=strict)
+        self.ict_head.load_state_dict(
+            state_dict[self._ict_head_key], strict=strict)
megatron/model/utils.py  (+39, -0)

@@ -78,3 +78,42 @@ def get_params_for_weight_decay_optimization(module):
                                      if p is not None and n == 'bias'])

     return weight_decay_params, no_weight_decay_params
+
+
+def bert_attention_mask_func(attention_scores, attention_mask):
+    attention_scores = attention_scores + attention_mask
+    return attention_scores
+
+
+def bert_extended_attention_mask(attention_mask, dtype):
+    # We create a 3D attention mask from a 2D tensor mask.
+    # [b, 1, s]
+    attention_mask_b1s = attention_mask.unsqueeze(1)
+    # [b, s, 1]
+    attention_mask_bs1 = attention_mask.unsqueeze(2)
+    # [b, s, s]
+    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
+    # [b, 1, s, s]
+    extended_attention_mask = attention_mask_bss.unsqueeze(1)
+
+    # Since attention_mask is 1.0 for positions we want to attend and 0.0
+    # for masked positions, this operation will create a tensor which is
+    # 0.0 for positions we want to attend and -10000.0 for masked positions.
+    # Since we are adding it to the raw scores before the softmax, this is
+    # effectively the same as removing these entirely.
+    # fp16 compatibility
+    extended_attention_mask = extended_attention_mask.to(dtype=dtype)
+    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+    return extended_attention_mask
+
+
+def bert_position_ids(token_ids):
+    # Create position ids
+    seq_length = token_ids.size(1)
+    position_ids = torch.arange(seq_length, dtype=torch.long,
+                                device=token_ids.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
+
+    return position_ids
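For reference, the moved mask helper turns a [b, s] mask of ones and zeros into a [b, 1, s, s] additive mask of 0.0 / -10000.0. A short, self-contained shape check that re-implements the same arithmetic in plain PyTorch:

    # Shape/value check for the moved helper: attending to a padded position
    # receives a -10000.0 penalty that is added to the raw attention scores.
    import torch

    def extended_mask(attention_mask, dtype):
        b1s = attention_mask.unsqueeze(1)           # [b, 1, s]
        bs1 = attention_mask.unsqueeze(2)           # [b, s, 1]
        bss = b1s * bs1                             # [b, s, s]
        ext = bss.unsqueeze(1).to(dtype=dtype)      # [b, 1, s, s]
        return (1.0 - ext) * -10000.0

    mask = torch.tensor([[1, 1, 1, 0]])             # one sequence, last position is padding
    ext = extended_mask(mask, torch.float32)
    assert ext.shape == (1, 1, 4, 4)
    assert ext[0, 0, 0, 3] == -10000.0              # attending to the padded position is blocked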
megatron/tokenizer/tokenizer.py  (+8, -0)

@@ -20,6 +20,7 @@ from abc import abstractmethod
 from .bert_tokenization import FullTokenizer as FullBertTokenizer
 from .gpt2_tokenization import GPT2Tokenizer
+from megatron.data.realm_dataset_utils import join_str_list


 def build_tokenizer(args):

@@ -155,6 +156,13 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
         text_tokens = self.tokenizer.tokenize(text)
         return self.tokenizer.convert_tokens_to_ids(text_tokens)

+    def decode_token_ids(self, token_ids):
+        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
+        exclude_list = ['[PAD]', '[CLS]']
+        non_pads = [t for t in tokens if t not in exclude_list]
+        joined_strs = join_str_list(non_pads)
+        return joined_strs
+
     @property
     def cls(self):
         return self.cls_id
megatron/training.py  (+3, -2)

@@ -218,9 +218,10 @@ def setup_model_and_optimizer(model_provider_func):
     else:
         args.iteration = 0

-    if args.iteration == 0 and isinstance(model.module.module, ICTBertModel):
+    unwrapped_model = model.module.module
+    if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'):
         print("Initializing ICT from pretrained BERT model", flush=True)
-        model.module.module.init_state_dict_from_bert()
+        unwrapped_model.init_state_dict_from_bert()

     return model, optimizer, lr_scheduler
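The isinstance check is replaced by hasattr-based duck typing, so training.py no longer needs a reference to the ICTBertModel class itself. A tiny illustration with made-up classes:

    # Made-up classes, only to illustrate the hasattr-based dispatch used above.
    class PlainBert:
        pass

    class IctLikeModel:
        def init_state_dict_from_bert(self):
            print("initializing from a pretrained BERT checkpoint")

    for unwrapped_model in (PlainBert(), IctLikeModel()):
        if hasattr(unwrapped_model, 'init_state_dict_from_bert'):
            unwrapped_model.init_state_dict_from_bert()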
pretrain_bert_ict.py  (+11, -3)

@@ -31,14 +31,17 @@ from megatron.utils import reduce_losses
 num_batches = 0


-def model_provider(only_query_model=False, only_block_model=False):
+def general_model_provider(only_query_model=False, only_block_model=False):
     """Build the model."""
     args = get_args()

-    print_rank_0('building BERT models ...')
+    if args.ict_head_size is None:
+        raise ValueError("Need to specify --ict-head-size to provide an ICTBertModel")
+
+    print_rank_0('building ICTBertModel...')

     # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
     model = ICTBertModel(
-        ict_head_size=128,
+        ict_head_size=args.ict_head_size,
         num_tokentypes=2,
         parallel_output=True,
         only_query_model=only_query_model,

@@ -47,6 +50,10 @@ def model_provider(only_query_model=False, only_block_model=False):
     return model


+def model_provider():
+    return general_model_provider(False, False)
+
+
 def get_batch(data_iterator):

     # Items and their type.
     keys = ['query_tokens', 'query_pad_mask',

@@ -98,6 +105,7 @@ def forward_step(data_iterator, model):
     all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
     all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits

+    # currently this assumes model parallel size == 1.
     dist.all_reduce(all_query_logits)
     dist.all_reduce(all_block_logits)
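The all_reduce calls gather every data-parallel rank's query and block embeddings so that each query can be scored against all blocks in the global batch (the new comment notes this currently assumes model-parallel size 1). A single-process sketch of that in-batch scoring; the sizes and the cross-entropy objective here are illustrative, not lifted verbatim from pretrain_bert_ict.py:

    # Single-process sketch of ICT in-batch scoring: each query embedding is scored
    # against every block embedding; the matching block sits on the diagonal.
    import torch
    import torch.nn.functional as F

    batch_size, ict_head_size = 4, 128
    query_logits = torch.randn(batch_size, ict_head_size)
    block_logits = torch.randn(batch_size, ict_head_size)

    retrieval_scores = query_logits @ block_logits.t()   # [batch, batch]
    targets = torch.arange(batch_size)                   # correct block for query i is block i
    loss = F.cross_entropy(retrieval_scores, targets)
    print(retrieval_scores.shape, float(loss))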