OpenDAS / Megatron-LM

Commit f7f730e1, authored Apr 23, 2020 by Neel Kant
Parent: f42b4d24

    Write pretrain_realm.py and misc dataset_type left from earlier
Showing 2 changed files with 30 additions and 56 deletions:

  pretrain_bert_ict.py   +1  -1
  pretrain_realm.py      +29 -55
pretrain_bert_ict.py
@@ -113,7 +113,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
         short_seq_prob=args.short_seq_prob,
         seed=args.seed,
         skip_warmup=(not args.mmap_warmup),
-        ict_dataset=True)
+        dataset_type='ict')
     print_rank_0("> finished creating BERT ICT datasets ...")

     return train_ds, valid_ds, test_ds
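The only functional change here swaps the boolean ict_dataset=True flag for a string selector, dataset_type='ict'. As a rough sketch of why a string selector is easier to extend than one boolean per dataset flavor (a hypothetical build_dataset helper, not the real megatron.data.bert_dataset.build_train_valid_test_datasets signature):

# Hypothetical sketch only: dispatch on a dataset_type string, so new flavors
# (e.g. 'realm') only need another entry instead of another boolean argument.
def build_dataset(dataset_type='standard_bert', **kwargs):
    builders = {
        'standard_bert': lambda: {'task': 'masked-lm', **kwargs},
        'ict': lambda: {'task': 'inverse-cloze', **kwargs},
    }
    if dataset_type not in builders:
        raise ValueError('unknown dataset_type: {}'.format(dataset_type))
    return builders[dataset_type]()

print(build_dataset(dataset_type='ict', seq_length=512))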
pretrain_realm.py
@@ -17,18 +17,16 @@
 import torch
 import torch.nn.functional as F
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from hashed_index import HashedIndex, load_ict_checkpoint, get_ict_dataset
 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
-from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import build_train_valid_test_datasets
-from megatron.model import ICTBertModel, REALMBertModel
-from megatron.training import get_model, pretrain
+from megatron.model import REALMBertModel, REALMRetriever
+from megatron.training import pretrain
 from megatron.utils import reduce_losses
-from pretrain_bert_ict import model_provider as ict_model_provider

 num_batches = 0
@@ -36,39 +34,21 @@ num_batches = 0
 def model_provider():
     """Build the model."""
     args = get_args()

-    print_rank_0('building BERT models ...')
-    ict_model = get_model(ict_model_provider)
-    if isinstance(ict_model, torchDDP):
-        model = ict_model.module
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
-    with open(tracker_filename, 'r') as f:
-        iteration = int(f.read().strip())
-    assert iteration > 0
-    checkpoint_name = get_checkpoint_name(args.load, iteration, False)
-    if mpu.get_data_parallel_rank() == 0:
-        print('global rank {} is loading checkpoint {}'.format(torch.distributed.get_rank(), checkpoint_name))
-    state_dict = torch.load(checkpoint_name, map_location='cpu')
-    ict_model.load_state_dict(state_dict['model'])
-    torch.distributed.barrier()
-    if mpu.get_data_parallel_rank() == 0:
-        print(' successfully loaded {}'.format(checkpoint_name))
-    realm_model = REALMBertModel(ict_model, args.block_hash_data_path)
-    return ict_model
+    print_rank_0('building REALM models ...')
+    ict_model = load_ict_checkpoint()
+    ict_dataset = get_ict_dataset()
+    hashed_index = HashedIndex.load_from_file('block_hash_data.pkl')
+    retriever = REALMRetriever(ict_model, ict_dataset, hashed_index)
+    model = REALMBertModel(retriever)
+    return model


 def get_batch(data_iterator):
     # Items and their type.
-    keys = ['query_tokens', 'query_types', 'query_pad_mask']
+    keys = ['tokens', 'labels', 'loss_mask', 'pad_mask']
     datatype = torch.int64

     # Broadcast data.
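The rewritten model_provider above delegates checkpoint loading to load_ict_checkpoint(), grabs an ICT dataset and a pickled HashedIndex, and wraps them in a REALMRetriever that feeds REALMBertModel. The toy sketch below only illustrates the general idea of a hashed (LSH-style) index over block embeddings; it is an assumption for illustration, not the HashedIndex or REALMRetriever API from this repository.

import torch

# Toy LSH-style index: bucket block embeddings by random-hyperplane signs,
# then score a query only against the blocks in its bucket.
class ToyHashedIndex:
    def __init__(self, block_embeds, num_planes=8, seed=0):
        torch.manual_seed(seed)
        self.block_embeds = block_embeds                          # [num_blocks, hidden]
        self.planes = torch.randn(block_embeds.shape[1], num_planes)
        self.buckets = {}
        for idx, code in enumerate(self._hash(block_embeds)):
            self.buckets.setdefault(code, []).append(idx)

    def _hash(self, embeds):
        signs = (embeds @ self.planes) > 0                        # [n, num_planes]
        return [tuple(row.tolist()) for row in signs]

    def retrieve(self, query_embed, top_k=2):
        candidates = self.buckets.get(self._hash(query_embed.unsqueeze(0))[0], [])
        if not candidates:
            candidates = list(range(self.block_embeds.shape[0]))  # empty bucket: scan all
        scores = self.block_embeds[candidates] @ query_embed      # dot-product relevance
        top = torch.topk(scores, k=min(top_k, len(candidates)))
        return [candidates[i] for i in top.indices.tolist()], top.values

blocks = torch.randn(100, 16)          # stand-in for ICT block embeddings
index = ToyHashedIndex(blocks)
block_ids, block_scores = index.retrieve(torch.randn(16))
print(block_ids, block_scores)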
@@ -79,11 +59,12 @@ def get_batch(data_iterator):
     data_b = mpu.broadcast_data(keys, data, datatype)

     # Unpack.
-    query_tokens = data_b['query_tokens'].long()
-    query_types = data_b['query_types'].long()
-    query_pad_mask = data_b['query_pad_mask'].long()
+    tokens = data_b['tokens'].long()
+    labels = data_b['labels'].long()
+    loss_mask = data_b['loss_mask'].long()
+    pad_mask = data_b['pad_mask'].long()

-    return query_tokens, query_types, query_pad_mask
+    return tokens, labels, loss_mask, pad_mask


 def forward_step(data_iterator, model):
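get_batch now broadcasts and unpacks the masked-LM keys 'tokens', 'labels', 'loss_mask', and 'pad_mask' instead of the ICT query_* keys. A toy batch under assumed semantics (labels hold the original ids at masked positions, loss_mask marks those positions, pad_mask marks real tokens; the actual megatron.data.bert_dataset layout may differ):

import torch

# Illustrative sample with the keys the new get_batch expects (assumed layout).
batch = {
    'tokens':    torch.tensor([[101, 2054, 103, 2003, 102, 0, 0]]),   # 103 = [MASK]
    'labels':    torch.tensor([[ -1,  -1, 7592,  -1,  -1, -1, -1]]),  # id hidden by the mask
    'loss_mask': torch.tensor([[  0,   0,   1,   0,   0,  0,  0]]),   # 1 only at masked slots
    'pad_mask':  torch.tensor([[  1,   1,   1,   1,   1,  0,  0]]),   # 1 for non-padding tokens
}

# Averaging only over masked positions mirrors the lm_loss expression in the diff below.
per_token_loss = torch.rand(1, 7)                                     # stand-in for lm_loss_
loss_mask = batch['loss_mask'].float()
masked_lm_loss = torch.sum(per_token_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(masked_lm_loss)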
@@ -92,29 +73,22 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
-    query_tokens, query_types, query_pad_mask = get_batch(data_iterator)
+    tokens, labels, loss_mask, pad_mask = get_batch(data_iterator)
     timers('batch generator').stop()

     # Forward model.
-    query_logits, block_logits = model(query_tokens, query_pad_mask, query_types,
-                                       # TODO: MAKE SURE PAD IS NOT 1 - PAD
-                                       block_tokens, block_pad_mask, block_types).float()
-    # [batch x h] * [h x batch]
-    retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
-    softmaxed = F.softmax(retrieval_scores, dim=1)
-    top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
-    batch_size = softmaxed.shape[0]
-    top1_acc = torch.cuda.FloatTensor([sum([int(top5_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
-    top5_acc = torch.cuda.FloatTensor([sum([int(i in top5_indices[i]) for i in range(batch_size)]) / batch_size])
-    retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size).cuda())
-    reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
-    return retrieval_loss, {'retrieval loss': reduced_losses[0], 'top1_acc': reduced_losses[1], 'top5_acc': reduced_losses[2]}
+    lm_logits, block_probs = model(tokens, pad_mask)
+
+    # P(y|x) = sum_z(P(y|z, x) * P(z|x))
+    lm_logits = torch.sum(lm_logits * block_probs, dim=1)
+    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(), labels.contiguous())
+
+    lm_loss = torch.sum(
+        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+
+    reduced_loss = reduce_losses([lm_loss])
+    return lm_loss, {'lm_loss': reduced_loss[0]}


 def train_valid_test_datasets_provider(train_val_test_num_samples):
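The new forward_step marginalizes the language-model prediction over retrieved blocks, as the inline comment P(y|x) = sum_z(P(y|z, x) * P(z|x)) states. The commit sums block-weighted logits over dim=1; the sketch below shows the same reduction in probability space with made-up shapes ([batch, num_blocks, seq, vocab] per-block predictions and [batch, num_blocks] retrieval probabilities are assumptions, not the REALMBertModel output format):

import torch
import torch.nn.functional as F

# Toy shapes: 2 queries, 3 retrieved blocks each, 5 tokens, 11-word vocabulary.
batch, num_blocks, seq_len, vocab = 2, 3, 5, 11

# P(y | z, x): per-block token distributions; P(z | x): retrieval distribution.
per_block_token_probs = F.softmax(torch.randn(batch, num_blocks, seq_len, vocab), dim=-1)
block_probs = F.softmax(torch.randn(batch, num_blocks), dim=-1)

# Marginalize over blocks: P(y | x) = sum_z P(y | z, x) * P(z | x).
marginal = torch.sum(per_block_token_probs * block_probs[:, :, None, None], dim=1)

print(marginal.shape)                                                  # torch.Size([2, 5, 11])
print(torch.allclose(marginal.sum(-1), torch.ones(batch, seq_len)))    # still sums to 1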