OpenDAS / Megatron-LM · Commit ee2490d5
Authored Apr 20, 2020 by Neel Kant
Parent: 2d98cfbf

Start creating REALMBertModel

Showing 5 changed files with 197 additions and 4 deletions.
megatron/data/ict_dataset.py    +6   -0
megatron/data/realm_dataset.py  +22  -1
megatron/model/__init__.py      +1   -1
megatron/model/bert_model.py    +48  -2
pretrain_realm.py               +120 -0
megatron/data/ict_dataset.py

```diff
@@ -89,6 +89,12 @@ class InverseClozeDataset(Dataset):
         token_types = [0] * self.max_seq_length
         return tokens, token_types, pad_mask

+    def get_block(self, start_idx, end_idx, doc_idx, block_idx):
+        block = [self.context_dataset[i] for i in range(start_idx, end_idx)]
+        title = list(self.titles_dataset[int(doc_idx)])
+        block = list(itertools.chain(*block))[self.max_seq_length - (3 + len(title))]
+
     def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
         if not num_epochs:
             if not max_num_samples:
```
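The new `get_block` helper stitches the sentences of a document span back into one block, leaving room for the title plus three special tokens. A minimal standalone sketch of that logic follows; the dataset names are illustrative stand-ins, and it uses a truncating slice (`[: ...]`) where the committed line indexes with `[...]`, which is presumably the intent.

```python
# Sketch only: plain lists stand in for Megatron's indexed datasets.
import itertools

MAX_SEQ_LENGTH = 16  # small assumed value, for illustration

def get_block(context_dataset, titles_dataset, start_idx, end_idx, doc_idx):
    # Gather the sentences of the block and the document title.
    block = [context_dataset[i] for i in range(start_idx, end_idx)]
    title = list(titles_dataset[int(doc_idx)])
    # Flatten and truncate so that title + 3 special tokens + block fit the budget
    # (the commit indexes here; a slice is assumed to be the intent).
    block = list(itertools.chain(*block))[:MAX_SEQ_LENGTH - (3 + len(title))]
    return block, title

if __name__ == "__main__":
    contexts = [[101, 7, 8], [9, 10, 11, 12], [13, 14], [15, 16, 17]]
    titles = [[201, 202]]
    block, title = get_block(contexts, titles, start_idx=0, end_idx=4, doc_idx=0)
    print(len(block) + len(title) + 3 <= MAX_SEQ_LENGTH)  # True
```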
megatron/data/realm_dataset.py

```diff
 import numpy as np
+import spacy
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
 from megatron.data.bert_dataset import get_samples_mapping_
 from megatron.data.dataset_utils import build_simple_training_sample

+qa_nlp = spacy.load('en_core_web_lg')
+

 class RealmDataset(Dataset):
-    """Dataset containing sentences and their blocks for an inverse cloze task."""
+    """Dataset containing simple masked sentences for masked language modeling.
+
+    The dataset should yield sentences just like the regular BertDataset.
+    However, this dataset also needs to be able to return a set of blocks
+    given their start and end indices.
+    """
     def __init__(self, name, indexed_dataset, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
                  max_seq_length, short_seq_prob, seed):

@@ -58,3 +68,14 @@ class RealmDataset(Dataset):
             self.mask_id, self.pad_id,
             self.masked_lm_prob, np_rng)
+
+
+def spacy_ner(block_text):
+    candidates = {}
+    block = qa_nlp(block_text)
+    starts = []
+    answers = []
+    for ent in block.ents:
+        starts.append(int(ent.start_char))
+        answers.append(str(ent.text))
+    candidates['starts'] = starts
+    candidates['answers'] = answers
```
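As committed, `spacy_ner` fills `candidates` but never returns it; a trailing `return candidates` is presumably intended. A minimal usage sketch under that assumption is below (it needs a downloaded spaCy model, e.g. `python -m spacy download en_core_web_sm`; the commit loads the larger `en_core_web_lg`).

```python
# Sketch of the presumed spacy_ner usage; the return statement is an
# assumption and is not present in this commit.
import spacy

qa_nlp = spacy.load('en_core_web_sm')  # smaller model, for illustration only

def spacy_ner(block_text):
    candidates = {}
    block = qa_nlp(block_text)
    starts, answers = [], []
    for ent in block.ents:
        starts.append(int(ent.start_char))  # character offset of each entity
        answers.append(str(ent.text))       # surface form of each entity
    candidates['starts'] = starts
    candidates['answers'] = answers
    return candidates  # assumed addition

print(spacy_ner("Megatron-LM was released by NVIDIA in 2019."))
# e.g. {'starts': [28, 38], 'answers': ['NVIDIA', '2019']}, depending on the model
```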
megatron/model/__init__.py

```diff
@@ -14,6 +14,6 @@
 # limitations under the License.

 from .distributed import *
-from .bert_model import BertModel, ICTBertModel
+from .bert_model import BertModel, ICTBertModel, REALMBertModel
 from .gpt2_model import GPT2Model
 from .utils import get_params_for_weight_decay_optimization
```
megatron/model/bert_model.py

```diff
@@ -214,8 +214,49 @@ class BertModel(MegatronModule):
                 state_dict[self._ict_head_key], strict=strict)


+# REALMBertModel is just BertModel without binary head.
+class REALMBertModel(MegatronModule):
+    # needs a different kind of dataset though
+    def __init__(self, ict_model_path, block_hash_data_path):
+        super(REALMBertModel, self).__init__()
+        bert_args = dict(
+            num_tokentypes=2,
+            add_binary_head=False,
+            parallel_output=True
+        )
+        self.lm_model = BertModel(**bert_args)
+        self._lm_key = 'realm_lm'
+
+        self.ict_model = ict_model
+        self.ict_dataset = ict_dataset
+        self.block_hash_data = block_hash_data
+
+    def forward(self, tokens, attention_mask, token_types):
+        # [batch_size x embed_size]
+        query_logits = self.ict_model.embed_query(tokens, attention_mask, token_types)
+        hash_matrix_pos = self.hash_data['matrix']
+
+        # [batch_size, num_buckets / 2]
+        query_hash_pos = torch.matmul(query_logits, hash_matrix_pos)
+        query_hash_full = torch.cat((query_hash_pos, -query_hash_pos), axis=1)
+
+        # [batch_size]
+        query_hashes = torch.argmax(query_hash_full, axis=1)
+
+        batch_block_embeds = []
+        for hash in query_hashes:
+            # TODO: this should be made into a single np.array in preprocessing
+            bucket_blocks = self.hash_data[hash]
+            block_indices = bucket_blocks[:, 3]
+
+            # [bucket_pop, embed_size]
+            block_embeds = [self.block_data[idx] for idx in block_indices]
+
+            # will become [batch_size, bucket_pop, embed_size]
+            # will require padding to do tensor multiplication
+            batch_block_embeds.append(block_embeds)
+
+        batch_block_embeds = np.array(batch_block_embeds)
+        retrieval_scores = query_logits.matmul(torch.transpose(batch_block_embeds, 0, 1))
+
+
 class ICTBertModel(MegatronModule):

@@ -249,6 +290,11 @@ class ICTBertModel(MegatronModule):
         return query_logits, block_logits

+    def embed_query(self, query_tokens, query_attention_mask, query_types):
+        query_ict_logits, _ = self.question_model.forward(
+            query_tokens, 1 - query_attention_mask, query_types)
+        return query_ict_logits
+
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """Save dict with state dicts of each of the models."""
```
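`REALMBertModel.forward` buckets each query embedding with a sign-random-projection (LSH-style) hash before scoring it against the block embeddings stored in the matching bucket; several attributes it reads, such as `ict_model` and `hash_data`, are not wired up yet in this commit. Below is a self-contained sketch of just the bucketing step, with made-up shapes and a random matrix standing in for `hash_data['matrix']`.

```python
# Sketch of the hashing step: project queries with a fixed random matrix,
# append the negated projection, and take the argmax as the bucket id.
import torch

embed_size, num_buckets, batch_size = 128, 32, 4
hash_matrix_pos = torch.randn(embed_size, num_buckets // 2)  # stand-in for hash_data['matrix']
query_logits = torch.randn(batch_size, embed_size)           # stand-in for embed_query output

query_hash_pos = torch.matmul(query_logits, hash_matrix_pos)           # [batch, num_buckets / 2]
query_hash_full = torch.cat((query_hash_pos, -query_hash_pos), dim=1)  # [batch, num_buckets]
query_hashes = torch.argmax(query_hash_full, dim=1)                    # one bucket id per query
print(query_hashes)  # e.g. tensor([ 5, 18,  0, 27])
```

Concatenating the projection with its negation means the argmax picks both the strongest projection direction and its sign, so each query lands in exactly one of `num_buckets` buckets.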
pretrain_realm.py (new file, 0 → 100644)

```python
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.model import ICTBertModel, REALMBertModel
from megatron.training import pretrain
from megatron.utils import reduce_losses

num_batches = 0


def model_provider():
    """Build the model."""
    args = get_args()
    print_rank_0('building BERT models ...')

    realm_model = REALMBertModel(args.ict_model_path, args.block_hash_data_path)

    return ict_model


def get_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_types', 'query_pad_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_types = data_b['query_types'].long()
    query_pad_mask = data_b['query_pad_mask'].long()

    return query_tokens, query_types, query_pad_mask


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    query_tokens, query_types, query_pad_mask = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    query_logits, block_logits = model(query_tokens, query_pad_mask, query_types,
                                       block_tokens, block_pad_mask, block_types).float()

    # [batch x h] * [h x batch]
    retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
    softmaxed = F.softmax(retrieval_scores, dim=1)
    top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
    batch_size = softmaxed.shape[0]

    top1_acc = torch.cuda.FloatTensor(
        [sum([int(top5_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
    top5_acc = torch.cuda.FloatTensor(
        [sum([int(i in top5_indices[i]) for i in range(batch_size)]) / batch_size])

    retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size).cuda())
    reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])

    return retrieval_loss, {'retrieval loss': reduced_losses[0],
                            'top1_acc': reduced_losses[1],
                            'top5_acc': reduced_losses[2]}


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        ict_dataset=True)
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
```
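The committed `forward_step` still passes `block_tokens`, `block_pad_mask`, and `block_types`, which are carried over from the ICT script and are not defined here, and it hands already-softmaxed probabilities to `F.cross_entropy`, which expects raw logits. Below is a self-contained CPU sketch of the intended in-batch retrieval objective, with random tensors standing in for the model outputs.

```python
# Sketch of the in-batch retrieval loss: each query should score highest
# against its own block, so the target indices are simply 0..batch_size-1.
import torch
import torch.nn.functional as F

batch_size, hidden = 8, 64
query_logits = torch.randn(batch_size, hidden)  # stand-in for query embeddings
block_logits = torch.randn(batch_size, hidden)  # stand-in for block embeddings

retrieval_scores = query_logits.matmul(block_logits.t())  # [batch, batch]
softmaxed = F.softmax(retrieval_scores, dim=1)

top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
labels = torch.arange(batch_size)
top1_acc = (top5_indices[:, 0] == labels).float().mean()
top5_acc = (top5_indices == labels.unsqueeze(1)).any(dim=1).float().mean()

# cross_entropy applies log-softmax itself, so feed it the raw scores.
retrieval_loss = F.cross_entropy(retrieval_scores, labels)
print(float(retrieval_loss), float(top1_acc), float(top5_acc))
```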