OpenDAS / Megatron-LM / Commits

Commit d7022c72
Authored Apr 24, 2020 by Neel Kant
Parent: 6f54f50f

Mostly debugged realm-mlm
Showing 7 changed files with 36 additions and 20 deletions.
    hashed_index.py                   +1   -1
    megatron/data/bert_dataset.py     +1   -1
    megatron/data/ict_dataset.py      +1   -1
    megatron/data/realm_dataset.py    +3   -3
    megatron/model/bert_model.py      +26  -12
    megatron/training.py              +1   -2
    pretrain_realm.py                 +3   -0
hashed_index.py

@@ -198,7 +198,7 @@ def load_ict_checkpoint():
     if isinstance(model, torchDDP):
         model = model.module
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
+    tracker_filename = get_checkpoint_tracker_filename(args.ict_load)
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())
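The one-line change points the ICT checkpoint loader at a dedicated args.ict_load path instead of the shared args.load, so the retriever checkpoint can be resolved from its own directory alongside the LM checkpoint. A minimal sketch of the tracker-file convention this code relies on, assuming a Megatron-style tracker that stores the latest saved iteration as plain text; the helper name and the exact filename below are illustrative, not the repository's API:

import os

def read_latest_iteration(load_dir):
    """Illustrative only: read a Megatron-style tracker file that holds
    the most recently saved iteration as a plain integer string."""
    tracker_filename = os.path.join(load_dir, 'latest_checkpointed_iteration.txt')
    with open(tracker_filename, 'r') as f:
        return int(f.read().strip())

# With separate load directories, the LM and the ICT retriever can each
# resolve their own latest checkpoint:
# lm_iter  = read_latest_iteration(args.load)
# ict_iter = read_latest_iteration(args.ict_load)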
megatron/data/bert_dataset.py

@@ -27,7 +27,6 @@ from megatron import mpu
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron.data.ict_dataset import InverseClozeDataset
-from megatron.data.realm_dataset import RealmDataset
 from megatron import print_rank_0

 DATASET_TYPES = ['standard_bert', 'ict', 'realm']

@@ -76,6 +75,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     print_split_stats('test', 2)

     def build_dataset(index, name):
+        from megatron.data.realm_dataset import RealmDataset
         dataset = None
         if splits[index + 1] > splits[index]:
             # Get the pointer to the original doc-idx so we can set it later.
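Moving the RealmDataset import from module level into build_dataset is the usual way to break a circular dependency: realm_dataset.py already imports from bert_dataset.py, so importing back at module scope would fail or see a partially initialized module. A runnable toy sketch of the pattern; the module and class names (mod_a, mod_b, RealmLikeDataset) are made up, not the repository's layout:

import importlib
import os
import sys
import tempfile
import textwrap

# Two throwaway modules: mod_b imports from mod_a at module level, while
# mod_a defers its import of mod_b into the function that needs it.
tmp = tempfile.mkdtemp()
sys.path.insert(0, tmp)

with open(os.path.join(tmp, "mod_a.py"), "w") as f:
    f.write(textwrap.dedent("""
    def build_dataset(kind):
        # Deferred import: mod_b is resolved only when this runs,
        # after both modules have finished loading.
        from mod_b import RealmLikeDataset
        return RealmLikeDataset() if kind == "realm" else None
    """))

with open(os.path.join(tmp, "mod_b.py"), "w") as f:
    f.write(textwrap.dedent("""
    from mod_a import build_dataset  # module-level import going the other way

    class RealmLikeDataset:
        pass
    """))

mod_b = importlib.import_module("mod_b")
print(type(mod_b.build_dataset("realm")).__name__)  # -> RealmLikeDataset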
megatron/data/ict_dataset.py

@@ -90,7 +90,7 @@ class InverseClozeDataset(Dataset):
         block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

-        return block_tokens, block_pad_mask
+        return (block_tokens, block_pad_mask)

     def concat_and_pad_tokens(self, tokens, title=None):
         """concat with special tokens and pad sequence to self.max_seq_length"""
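The truncation to self.max_seq_length - (3 + len(title)) reserves room for the title plus three special tokens before padding. A minimal standalone sketch of what concat_and_pad_tokens plausibly does, assuming a [CLS] title [SEP] block [SEP] layout, pad id 0, and illustrative token ids; none of these constants are taken from the repository:

import numpy as np

CLS, SEP, PAD = 101, 102, 0  # illustrative ids; the real ones come from the tokenizer

def concat_and_pad_tokens(tokens, title, max_seq_length):
    """Sketch: concatenate with special tokens and pad to max_seq_length.
    Assumed layout: [CLS] title [SEP] tokens [SEP], which accounts for the
    3 + len(title) positions reserved before the block is truncated."""
    tokens = [CLS] + title + [SEP] + tokens + [SEP]
    pad_length = max_seq_length - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * pad_length
    tokens = tokens + [PAD] * pad_length
    return np.array(tokens, dtype=np.int64), np.array(pad_mask, dtype=np.int64)

# e.g. with max_seq_length=8:
# concat_and_pad_tokens([7, 8, 9], [5], 8)
# -> tokens   [101, 5, 102, 7, 8, 9, 102, 0]
#    pad_mask [  1, 1,   1, 1, 1, 1,   1, 0]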
megatron/data/realm_dataset.py

@@ -7,8 +7,8 @@ from megatron import get_tokenizer
 from megatron.data.bert_dataset import BertDataset, get_samples_mapping_
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy

-qa_nlp = spacy.load('en_core_web_lg')
+# qa_nlp = spacy.load('en_core_web_lg')
+qa_nlp = None


 class RealmDataset(BertDataset):
     """Dataset containing simple masked sentences for masked language modeling.

@@ -47,7 +47,7 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
                                    masked_labels, pad_id, max_seq_length)

     # REALM true sequence length is twice as long but none of that is to be predicted with LM
-    loss_mask_np = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1)
+    loss_mask_np = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1).astype(np.int64)

     train_sample = {
         'tokens': tokens_np,
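The training sample built here covers the query and its retrieved block concatenated into one sequence, so the per-token loss mask has to be extended to the doubled length and cast to int64 to match the other integer arrays. A small illustrative numpy check of the shape bookkeeping; the mask values are placeholders, not the repository's data:

import numpy as np

# Hypothetical query-side loss mask (seq_length == 4): 1 marks a masked position.
loss_mask_np = np.array([0, 1, 0, 1])

# Extend to the doubled (query + block) sequence length, as in the diff above.
loss_mask_np = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1).astype(np.int64)

print(loss_mask_np)        # [0 1 0 1 1 1 1 1]
print(loss_mask_np.shape)  # (8,) == 2 * seq_length
print(loss_mask_np.dtype)  # int64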
megatron/model/bert_model.py

@@ -234,22 +234,35 @@ class REALMBertModel(MegatronModule):
     def forward(self, tokens, attention_mask):
         # [batch_size x 5 x seq_length]
         top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(tokens, attention_mask)
+        batch_size = tokens.shape[0]
+        seq_length = top5_block_tokens.shape[2]
+        top5_block_tokens = torch.cuda.LongTensor(top5_block_tokens).reshape(-1, seq_length)
+        top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)

+        # [batch_size x 5 x embed_size]
+        fresh_block_logits = self.retriever.ict_model.module.module.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)

+        # [batch_size x embed_size x 1]
+        query_logits = self.retriever.ict_model.module.module.embed_query(tokens, attention_mask).unsqueeze(2)

         # [batch_size x 5]
-        fresh_block_logits = self.retriever.ict_model.embed_block(top5_block_tokens, top5_block_attention_mask)
-        block_probs = F.softmax(fresh_block_logits, axis=1)
+        fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
+        block_probs = F.softmax(fresh_block_scores, dim=1)

-        # [batch_size x 5 x seq_length]
-        tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1)
-        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1)
+        # [batch_size * 5 x seq_length]
+        tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
+        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)

-        # [batch_size x 5 x 2 * seq_length]
-        all_tokens = torch.cat((tokens, top5_block_tokens), axis=2)
-        all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=2)
+        # [batch_size * 5 x 2 * seq_length]
+        all_tokens = torch.cat((tokens, top5_block_tokens), axis=1)
+        all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=1)
         all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()

         # [batch_size x 5 x 2 * seq_length x vocab_size]
         lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
+        lm_logits = lm_logits.reshape(batch_size, 5, 2 * seq_length, -1)

         return lm_logits, block_probs

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
@@ -263,7 +276,7 @@ class REALMBertModel(MegatronModule):
 class REALMRetriever(MegatronModule):
-    """Retriever which uses a pretrained ICTBertModel and a hashed_index"""
+    """Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
     def __init__(self, ict_model, ict_dataset, hashed_index, top_k=5):
         super(REALMRetriever, self).__init__()
         self.ict_model = ict_model
@@ -301,13 +314,14 @@ class REALMRetriever(MegatronModule):
             top5_start_end_doc = [bucket[idx][:3] for idx in top5_indices.squeeze()]

             # top_k tuples of (block_tokens, block_pad_mask)
-            top5_block_data = [(self.ict_dataset.get_block(*indices)) for indices in top5_start_end_doc]
-            top5_tokens, top5_pad_masks = zip(top5_block_data)
+            top5_block_data = [self.ict_dataset.get_block(*indices) for indices in top5_start_end_doc]
+            top5_tokens, top5_pad_masks = zip(*top5_block_data)

             all_top5_tokens.append(np.array(top5_tokens))
             all_top5_pad_masks.append(np.array(top5_pad_masks))

-        return all_top5_tokens, all_top5_pad_masks
+        return np.array(all_top5_tokens), np.array(all_top5_pad_masks)


 class ICTBertModel(MegatronModule):
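The zip(top5_block_data) to zip(*top5_block_data) fix matters because zip over the list itself yields five 1-tuples, which cannot even be unpacked into two names; the starred form transposes the list of (tokens, pad_mask) pairs into one tuple of tokens and one of pad masks, which is what the returned per-query arrays need. A tiny standalone illustration with placeholder pairs:

# Each retrieved block yields a (block_tokens, block_pad_mask) pair; the
# values here are just placeholders.
top5_block_data = [(f"tokens_{i}", f"mask_{i}") for i in range(5)]

# zip(top5_block_data) produces five 1-tuples, not two aligned sequences:
print(list(zip(top5_block_data))[0])   # (('tokens_0', 'mask_0'),)

# zip(*top5_block_data) transposes the pairs:
top5_tokens, top5_pad_masks = zip(*top5_block_data)
print(top5_tokens)      # ('tokens_0', ..., 'tokens_4')
print(top5_pad_masks)   # ('mask_0', ..., 'mask_4')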
megatron/training.py

@@ -225,6 +225,7 @@ def backward_step(optimizer, model, loss):
     """Backward step."""
     args = get_args()
     timers = get_timers()
+    print("start backward", flush=True)

     # Backward pass.
     optimizer.zero_grad()

@@ -239,11 +240,9 @@ def backward_step(optimizer, model, loss):
         model.allreduce_params(reduce_after=False,
                                fp32_allreduce=args.fp32_allreduce)
         timers('allreduce').stop()

     # Update master gradients.
     if args.fp16:
         optimizer.update_master_grads()

     # Clipping gradients helps prevent the exploding gradient.
     if args.clip_grad > 0:
         if not args.fp16:
pretrain_realm.py

@@ -74,6 +74,7 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
     tokens, labels, loss_mask, pad_mask = get_batch(data_iterator)
+    labels = torch.cat((labels, labels), axis=-1)
     timers('batch generator').stop()

     # Forward model.
@@ -81,6 +82,7 @@ def forward_step(data_iterator, model):
     lm_logits, block_probs = model(tokens, pad_mask)

     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
+    block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
     lm_logits = torch.sum(lm_logits * block_probs, dim=1)

     lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                 labels.contiguous())
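The added line broadcasts the per-block probabilities P(z|x), shaped [batch_size x 5], over the per-block LM logits, shaped [batch_size x 5 x 2*seq_length x vocab_size], so that the sum over dim=1 implements the REALM marginalization P(y|x) = sum_z(P(y|z, x) * P(z|x)) noted in the comment. A shape-only sketch with random tensors; all sizes are illustrative:

import torch
import torch.nn.functional as F

batch_size, k, seq_length, vocab_size = 2, 5, 3, 11  # illustrative sizes

# Stand-ins for the model outputs in forward_step.
lm_logits = torch.randn(batch_size, k, 2 * seq_length, vocab_size)
block_probs = F.softmax(torch.randn(batch_size, k), dim=1)  # P(z|x), sums to 1 over k

# Broadcast P(z|x) over the sequence and vocab dims, then marginalize over blocks.
block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
marginal_logits = torch.sum(lm_logits * block_probs, dim=1)

print(marginal_logits.shape)  # torch.Size([2, 6, 11]) == [batch_size, 2*seq_length, vocab_size]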
@@ -88,6 +90,7 @@ def forward_step(data_iterator, model):
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

     reduced_loss = reduce_losses([lm_loss])
+    print(reduced_loss, flush=True)
     return lm_loss, {'lm_loss': reduced_loss[0]}