OpenDAS / Megatron-LM / Commits

Commit 32bb4edc, authored Jun 05, 2020 by Neel Kant
Prune changes to only be related to ICT
Parent: 674814a5

Showing 7 changed files with 22 additions and 637 deletions (+22 −637)
megatron/model/realm_model.py    +6   −296
megatron/mpu/data.py             +1   −1
megatron/mpu/initialize.py       +0   −57
megatron/training.py             +9   −68
megatron/utils.py                +3   −9
pretrain_bert_ict.py             +3   −10
pretrain_realm.py                +0   −196
megatron/model/realm_model.py (view file @ 32bb4edc)
import os
import numpy as np
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import load_checkpoint, get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.model import BertModel
from megatron.model.utils import get_linear_layer, init_method_normal
...
@@ -12,294 +13,6 @@ from megatron.utils import report_memory
from megatron import mpu
class REALMAnswerSpanModel(MegatronModule):
    def __init__(self, realm_model, mlp_hidden_size=64):
        super(REALMAnswerSpanModel, self).__init__()
        self.realm_model = realm_model
        self.mlp_hidden_size = mlp_hidden_size

        args = get_args()
        init_method = init_method_normal(args.init_method_std)
        self.fc1 = get_linear_layer(2 * args.hidden_size, self.mlp_hidden_size, init_method)
        self._fc1_key = 'fc1'
        self.fc2 = get_linear_layer(self.mlp_hidden_size, 1, init_method)
        self._fc2_key = 'fc2'

        max_length = 10
        self.start_ends = []
        for length in range(max_length):
            self.start_ends.extend([(i, i + length) for i in range(288 - length)])

    def forward(self, question_tokens, question_attention_mask, answer_tokens, answer_token_lengths):
        lm_logits, block_probs, topk_block_tokens = self.realm_model(
            question_tokens, question_attention_mask,
            query_block_indices=None, return_topk_block_tokens=True)

        batch_span_reps, batch_loss_masks = [], []
        # go through batch one-by-one
        for i in range(len(answer_token_lengths)):
            answer_length = answer_token_lengths[i]
            answer_span_tokens = answer_tokens[i][:answer_length]
            span_reps, loss_masks = [], []

            # go through the top k for the batch item
            for logits, block_tokens in zip(lm_logits[i], topk_block_tokens[i]):
                block_logits = logits[len(logits) / 2:]
                span_starts = range(len(block_tokens) - (answer_length - 1))

                # record the start, end indices of spans which match the answer
                matching_indices = set([
                    (idx, idx + answer_length - 1) for idx in span_starts
                    if np.array_equal(block_tokens[idx:idx + answer_length], answer_span_tokens)
                ])

                # create a mask for computing the loss on P(y | z, x)
                # [num_spans]
                loss_masks.append(torch.LongTensor(
                    [int(idx_pair in matching_indices) for idx_pair in self.start_ends]))

                # get all of the candidate spans that need to be fed to MLP
                # [num_spans x 2 * embed_size]
                span_reps.append([torch.cat((block_logits[s], block_logits[e]))
                                  for (s, e) in self.start_ends])

            # data for all k blocks for a single batch item
            # [k x num_spans]
            batch_loss_masks.append(torch.stack(loss_masks))
            # [k x num_spans x 2 * embed_size]
            batch_span_reps.append(torch.stack(span_reps))

        # data for all batch items
        # [batch_size x k x num_spans]
        batch_loss_masks = torch.stack(batch_loss_masks)
        batch_span_reps = torch.stack(batch_span_reps)

        # [batch_size x k x num_spans]
        batch_span_logits = self.fc2(self.fc1(batch_span_reps)).squeeze()
        return batch_span_logits, batch_loss_masks, block_probs

        # block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
        # lm_logits = torch.sum(lm_logits * block_probs, dim=1)
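For reference, a minimal standalone sketch (toy sizes and hypothetical token values, not part of the diff) of the span enumeration and answer-matching mask that REALMAnswerSpanModel.forward builds above:

import numpy as np
import torch

# Toy sizes; the class above uses max_length = 10 over 288 positions.
max_length, num_positions = 3, 6
start_ends = []
for length in range(max_length):
    start_ends.extend([(i, i + length) for i in range(num_positions - length)])

# Hypothetical block tokens and a two-token answer span.
block_tokens = np.array([5, 7, 9, 7, 9, 2])
answer_span_tokens = np.array([7, 9])
answer_length = len(answer_span_tokens)

span_starts = range(len(block_tokens) - (answer_length - 1))
matching_indices = set(
    (idx, idx + answer_length - 1) for idx in span_starts
    if np.array_equal(block_tokens[idx:idx + answer_length], answer_span_tokens))

# 1 wherever a candidate span exactly matches the answer, 0 elsewhere.
loss_mask = torch.LongTensor([int(pair in matching_indices) for pair in start_ends])
print(matching_indices)   # e.g. {(1, 2), (3, 4)}
print(loss_mask.sum())    # tensor(2)

Precomputing the fixed start_ends table lets every block share one [num_spans] mask and one batched MLP pass over all candidate spans.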
class REALMBertModel(MegatronModule):
    def __init__(self, retriever):
        super(REALMBertModel, self).__init__()
        bert_args = dict(num_tokentypes=2, add_binary_head=False, parallel_output=True)
        self.lm_model = BertModel(**bert_args)
        load_checkpoint(self.lm_model, optimizer=None, lr_scheduler=None)
        self._lm_key = 'realm_lm'

        self.retriever = retriever
        self.top_k = self.retriever.top_k
        self._retriever_key = 'retriever'

    def forward(self, tokens, attention_mask, query_block_indices, return_topk_block_tokens=False):
        # print("\nNEW FORWARD", '-' * 100, flush=True)
        dset = self.retriever.ict_dataset
        det_tokens = detach(tokens)[0].tolist()
        det_attention = detach(attention_mask)[0].tolist()
        # print("\nTokens: ", det_tokens, '\n', flush=True)
        # print("\nAttention: ", det_attention, '\n', flush=True)
        # print("pad id: ", dset.pad_id, flush=True)
        assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
        if 0 in det_attention:
            idx_padid = det_tokens.index(dset.pad_id)
            idx_attn = det_attention.index(0)
            assert idx_padid == idx_attn, (idx_padid, idx_attn)

        # text = dset.decode_tokens(det_tokens)
        # print(text, flush=True)
        # print("Token shape: ", tokens.shape, flush=True)

        # [batch_size x k x seq_length]
        topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
            tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
        # print("Top k block shape: ", topk_block_tokens.shape, flush=True)

        batch_size = tokens.shape[0]
        # create a copy in case it needs to be returned
        ret_topk_block_tokens = np.array(topk_block_tokens)

        seq_length = topk_block_tokens.shape[2]
        long_tensor = torch.cuda.LongTensor
        topk_block_tokens = long_tensor(topk_block_tokens).reshape(-1, seq_length)
        topk_block_attention_mask = long_tensor(topk_block_attention_mask).reshape(-1, seq_length)
        # print('Block token shape: ', topk_block_tokens.shape, flush=True)

        # [batch_size x k x embed_size]
        true_model = self.retriever.ict_model.module.module
        fresh_block_logits = mpu.checkpoint(true_model.embed_block,
                                            topk_block_tokens, topk_block_attention_mask)
        fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
        # print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)

        # [batch_size x embed_size x 1]
        query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2)
        # print('Query logits shape: ', query_logits.shape, flush=True)

        # [batch_size x k]
        fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
        # print('Block score shape: ', fresh_block_scores.shape, flush=True)
        block_probs = F.softmax(fresh_block_scores, dim=1)

        # [batch_size * k x seq_length]
        tokens = torch.stack([tokens.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
        #assert all(tokens[i] == tokens[0] for i in range(self.top_k))
        #assert all(tokens[i] == tokens[self.top_k] for i in range(self.top_k, 2 * self.top_k))
        #assert not any(tokens[i] == tokens[0] for i in range(self.top_k, batch_size * self.top_k))
        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * self.top_k,
                                     dim=1).reshape(-1, seq_length)

        # [batch_size * k x 2 * seq_length]
        lm_input_batch_shape = (batch_size * self.top_k, 2 * seq_length)
        all_tokens = torch.zeros(lm_input_batch_shape).long().cuda()
        all_attention_mask = all_tokens.clone()
        all_token_types = all_tokens.clone()

        #all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
        #all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
        #all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()

        query_lengths = torch.sum(attention_mask, axis=1)
        # all blocks (including null ones) will have two SEP tokens
        block_sep_indices = (topk_block_tokens == dset.sep_id).nonzero().reshape(
            batch_size * self.top_k, 2, 2)
        # block body starts after the first SEP
        block_starts = block_sep_indices[:, 0, 1] + 1
        # block body ends after the second SEP
        block_ends = block_sep_indices[:, 1, 1] + 1
        # block_lengths = torch.sum(topk_block_attention_mask, axis=1)

        for row_num in range(all_tokens.shape[0]):
            q_len = query_lengths[row_num]
            b_start = block_starts[row_num]
            b_end = block_ends[row_num]

            # new tokens = CLS + query + SEP + block + SEP
            new_tokens_length = q_len + b_end - b_start

            # splice query and block tokens accordingly
            all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
            all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
            all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
            # print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)

            all_attention_mask[row_num, :new_tokens_length] = 1
            all_attention_mask[row_num, new_tokens_length:] = 0

        # [batch_size x k x 2 * seq_length x vocab_size]
        lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
        lm_logits = lm_logits.reshape(batch_size, self.top_k, 2 * seq_length, -1)

        if return_topk_block_tokens:
            return lm_logits, block_probs, ret_topk_block_tokens

        return lm_logits, block_probs

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        state_dict_[self._retriever_key] = self.retriever.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        self.lm_model.load_state_dict(state_dict[self._lm_key], strict)
        self.retriever.load_state_dict(state_dict[self._retriever_key], strict)
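A minimal sketch (random tensors and assumed toy shapes, not part of the diff) of the block-scoring step in REALMBertModel.forward above, where P(z | x) comes from a softmax over query-block inner products:

import torch
import torch.nn.functional as F

# Toy shapes: batch_size=2, top_k=3, embed_size=4.
batch_size, top_k, embed_size = 2, 3, 4
fresh_block_logits = torch.randn(batch_size, top_k, embed_size)   # [batch x k x embed]
query_logits = torch.randn(batch_size, embed_size).unsqueeze(2)   # [batch x embed x 1]

# Inner-product scores between the query and each retrieved block,
# normalized into P(z | x), mirroring the forward pass above.
fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()  # [batch x k]
block_probs = F.softmax(fresh_block_scores, dim=1)
print(block_probs.shape)        # torch.Size([2, 3])
print(block_probs.sum(dim=1))   # each row sums to ~1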
class REALMRetriever(MegatronModule):
    """Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
    def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
        super(REALMRetriever, self).__init__()
        self.ict_model = ict_model
        self.ict_dataset = ict_dataset
        self.block_data = block_data
        self.hashed_index = hashed_index
        self.top_k = top_k
        self._ict_key = 'ict_model'

    def reload_index(self):
        args = get_args()
        self.block_data = BlockData.load_from_file(args.block_data_path)
        print("resetting index", flush=True)
        self.hashed_index.reset_index()
        self.hashed_index.add_block_embed_data(self.block_data)

    def prep_query_text_for_retrieval(self, query_text):
        padless_max_len = self.ict_dataset.max_seq_length - 2
        query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
        query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)

        query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
        query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))

        return query_tokens, query_pad_mask

    def retrieve_evidence_blocks_text(self, query_text):
        """Get the top k evidence blocks for query_text in text form"""
        print("-" * 100)
        print("Query: ", query_text)

        query_tokens, query_pad_mask = self.prep_query_text_for_retrieval(query_text)
        topk_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
        for i, block in enumerate(topk_block_tokens[0]):
            block_text = self.ict_dataset.decode_tokens(block)
            print('\n> Block {}: {}'.format(i, block_text))

    def retrieve_evidence_blocks(self, query_tokens, query_pad_mask,
                                 query_block_indices=None, include_null_doc=False):
        """Embed blocks to be used in a forward pass"""
        with torch.no_grad():
            if hasattr(self.ict_model, 'module'):
                true_model = self.ict_model.module
                if hasattr(true_model, 'module'):
                    true_model = true_model.module
            else:
                true_model = self.ict_model
            # print("true model: ", true_model, flush=True)
            query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)

        _, block_indices = self.hashed_index.search_mips_index(
            query_embeds, top_k=self.top_k, reconstruct=False)
        all_topk_tokens, all_topk_pad_masks = [], []

        # this will result in no candidate exclusion
        if query_block_indices is None:
            query_block_indices = [-1] * len(block_indices)

        top_k_offset = int(include_null_doc)
        for query_idx, indices in enumerate(block_indices):
            # [k x meta_dim]
            # exclude trivial candidate if it appears, else just trim the weakest in the top-k
            topk_metas = [self.block_data.meta_data[idx] for idx in indices
                          if idx != query_block_indices[query_idx]]
            topk_block_data = [self.ict_dataset.get_block(*block_meta)
                               for block_meta in topk_metas[:self.top_k - top_k_offset]]
            if include_null_doc:
                topk_block_data.append(self.ict_dataset.get_null_block())
            topk_tokens, topk_pad_masks = zip(*topk_block_data)

            all_topk_tokens.append(np.array(topk_tokens))
            all_topk_pad_masks.append(np.array(topk_pad_masks))

        # [batch_size x k x seq_length]
        return np.array(all_topk_tokens), np.array(all_topk_pad_masks)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._ict_key] = self.ict_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        self.ict_model.load_state_dict(state_dict[self._ict_key], strict)
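A short usage sketch, assuming a retriever has already been assembled the way pretrain_realm.py's model_provider does near the end of this diff (ict_model, ict_dataset, all_block_data and hashed_index are those objects, not defined here):

# retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, top_k=5)
# retriever.retrieve_evidence_blocks_text("an example question")  # prints the decoded top-k blocks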
class ICTBertModel(MegatronModule):
    """Bert-based module for Inverse Cloze task."""
    def __init__(self,
...
@@ -341,10 +54,6 @@ class ICTBertModel(MegatronModule):
        block_logits = self.embed_block(block_tokens, block_attention_mask)
        return query_logits, block_logits

        # [batch x embed] * [embed x batch]
        # retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
        # return retrieval_scores

    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model"""
        if self.use_query_model:
...
@@ -391,10 +100,8 @@ class ICTBertModel(MegatronModule):
            state_dict[self._block_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining"""
        args = get_args()

        import os
        from megatron import mpu
        from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT load for ICT")
...
@@ -412,8 +119,11 @@ class ICTBertModel(MegatronModule):
        except BaseException:
            raise ValueError("Could not load checkpoint")

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']
        self.query_model.language_model.load_state_dict(model_dict)
        self.block_model.language_model.load_state_dict(model_dict)

        # give each model the same ict_head to begin with as well
        query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
        self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
megatron/mpu/data.py (view file @ 32bb4edc)
...
@@ -78,7 +78,7 @@ def broadcast_data(keys, data, datatype):
    members of the same model parallel group.

    Arguments:
        keys: list of keys in the data dictionary to be broadcasted
        keys: list of keys in the data disctionary to be broadcasted
        data: data dictionary of string keys and cpu tensor values.
        datatype: torch data type of all tensors in data associated
                  with keys.
...
megatron/mpu/initialize.py (view file @ 32bb4edc)
...
@@ -16,7 +16,6 @@
"""Model and data parallel groups."""

import datetime
import torch

from .utils import ensure_divisibility
...
@@ -27,11 +26,6 @@ _MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

_GLOO_COMM_GROUP = None
_TRAIN_GROUP = None
_INDEX_GROUP = None
_INDEX_READY = None

# These values enable us to change the mpu sizes on the fly.
_MPU_WORLD_SIZE = None
_MPU_RANK = None
...
@@ -102,13 +96,6 @@ def get_model_parallel_group():
    return _MODEL_PARALLEL_GROUP


def set_model_parallel_group(group):
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, \
        'model parallel group has already been initialized'
    _MODEL_PARALLEL_GROUP = group


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, \
...
@@ -116,13 +103,6 @@ def get_data_parallel_group():
    return _DATA_PARALLEL_GROUP


def set_data_parallel_group(group):
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, \
        'data parallel group has already been initialized'
    _DATA_PARALLEL_GROUP = group


def set_model_parallel_world_size(world_size):
    """Set the model parallel size"""
    global _MPU_WORLD_SIZE
...
@@ -175,40 +155,3 @@ def destroy_model_parallel():
    _MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None


def init_realm_groups(max_training_rank, world_size):
    global _GLOO_COMM_GROUP
    _GLOO_COMM_GROUP = torch.distributed.new_group(list(range(world_size)), backend="gloo",
                                                   timeout=datetime.timedelta(0, 7200))
    global _TRAIN_GROUP
    _TRAIN_GROUP = torch.distributed.new_group(list(range(max_training_rank)))
    global _INDEX_GROUP
    _INDEX_GROUP = torch.distributed.new_group(list(range(max_training_rank, world_size)))
    global _INDEX_READY
    _INDEX_READY = torch.zeros(1)


def get_gloo_comm_group():
    global _GLOO_COMM_GROUP
    assert _GLOO_COMM_GROUP is not None
    return _GLOO_COMM_GROUP


def get_train_group():
    global _TRAIN_GROUP
    assert _TRAIN_GROUP is not None
    return _TRAIN_GROUP


def get_index_group():
    global _INDEX_GROUP
    assert _INDEX_GROUP is not None
    return _INDEX_GROUP


def get_index_ready():
    global _INDEX_READY
    assert _INDEX_READY is not None
    return _INDEX_READY
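A small sketch (hypothetical sizes, not part of the diff) of how the removed init_realm_groups partitions ranks between training and asynchronous index building:

# With world_size = 8 and max_training_rank = 6 the removed helper would build:
world_size, max_training_rank = 8, 6
gloo_ranks = list(range(world_size))                       # [0..7], gloo backend, 2 h timeout
train_ranks = list(range(max_training_rank))               # [0, 1, 2, 3, 4, 5]
index_ranks = list(range(max_training_rank, world_size))   # [6, 7]
# _INDEX_READY is the one-element tensor the two halves use as a handshake flag.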
megatron/training.py (view file @ 32bb4edc)
...
@@ -18,8 +18,6 @@
from datetime import datetime
import math
import sys
import time

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam
...
@@ -37,19 +35,14 @@ from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group, get_gloo_comm_group
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import make_data_loader
from megatron.utils import report_memory

INDEX_READY = None


def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults={},
             initializer_func=None):
             forward_step_func, extra_args_provider=None, args_defaults={}):
    """Main training program.

    This function will run the followings in the order provided:
...
@@ -75,14 +68,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
    """

    # Initalize and get arguments, timers, and Tensorboard writer.
    if initializer_func is None:
        initialize_megatron(extra_args_provider=extra_args_provider,
                            args_defaults=args_defaults)
    else:
        initializer_func(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)
    global INDEX_READY
    INDEX_READY = get_index_ready()
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    args = get_args()
    timers = get_timers()
...
@@ -232,10 +219,8 @@ def setup_model_and_optimizer(model_provider_func):
        args.iteration = 0

    if args.iteration == 0 and isinstance(model.module.module, ICTBertModel):
        print("Yes, located ICT model", flush=True)
        print("Initializing ICT from pretrained BERT model", flush=True)
        model.module.module.init_state_dict_from_bert()
    elif args.iteration == 0:
        print("Ooops", flush=True)

    return model, optimizer, lr_scheduler
...
@@ -244,15 +229,12 @@ def backward_step(optimizer, model, loss):
    """Backward step."""
    args = get_args()
    timers = get_timers()

    # torch.cuda.synchronize()
    # Backward pass.
    # optimizer.zero_grad(set_grads_to_None=True)
    optimizer.zero_grad(set_grads_to_None=True)
    if args.fp16:
        optimizer.zero_grad(set_grads_to_None=True)
        optimizer.backward(loss, update_master_grads=False)
    else:
        optimizer.zero_grad()
        loss.backward()

    # All-reduce if needed.
...
@@ -261,9 +243,11 @@ def backward_step(optimizer, model, loss):
            model.allreduce_params(reduce_after=False,
                                   fp32_allreduce=args.fp32_allreduce)
        timers('allreduce').stop()

    # Update master gradients.
    if args.fp16:
        optimizer.update_master_grads()

    # Clipping gradients helps prevent the exploding gradient.
    if args.clip_grad > 0:
        if not args.fp16:
...
@@ -283,12 +267,11 @@ def train_step(forward_step_func, data_iterator,
    loss, loss_reduced = forward_step_func(data_iterator, model)
    timers('forward').stop()

    # Calculate gradients, reduce across processes, and clip.
    timers('backward').start()
    backward_step(optimizer, model, loss)
    timers('backward').stop()

    # Calculate gradients, reduce across processes, and clip.
    # Update parameters.
    timers('optimizer').start()
    optimizer.step()
...
@@ -383,54 +366,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    timers('interval time').start()
    report_memory_flag = True

    global INDEX_READY
    print('>>> Starting train()', flush=True)
    # start off by posting a receive call which will be answered.
    # synchronize for start
    if args.max_training_rank is not None:
        torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
        recv_handle = torch.distributed.broadcast(
            INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
        last_reload_iteration = iteration

    while iteration < args.train_iters:
        if args.max_training_rank is not None and \
                iteration >= last_reload_iteration + 500 and not recv_handle.is_completed():
            time.sleep(5)
            continue

        # this only applies for realm right here
        if args.max_training_rank is not None and recv_handle.is_completed():
            # should add check that INDEX_READY == 1 but what else could be happening
            true_model = model
            if hasattr(true_model, 'module'):
                true_model = true_model.module
            if hasattr(true_model, 'module'):
                true_model = true_model.module

            print("> Saving model and reloading index", flush=True)
            if args.rank == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)
            true_model.retriever.reload_index()

            if args.rank == 0:
                INDEX_READY = 1 - INDEX_READY
                torch.cuda.synchronize()
                # send handle
                torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())

            torch.distributed.barrier(get_data_parallel_group())
            recv_handle = torch.distributed.broadcast(
                INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
            last_reload_iteration = iteration
        elif iteration < 20:
            print("moving right along", flush=True)

        # report_memory("iteration {}".format(iteration))
        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        skipped_iters += skipped_iter
        iteration += 1
...
@@ -463,7 +404,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                              iteration, False)

        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier(get_data_parallel_group())
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_rank_0('rank: {} | time: {} | exiting the program at '
...
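The loop being deleted above coordinates trainer ranks and index-builder ranks through the one-element INDEX_READY tensor; a compressed sketch of that handshake (assuming the rank layout from the removed init_realm_groups) is:

# Trainers post an async receive-broadcast from rank max_training_rank and poll it each step.
#   recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank,
#                                             group=get_gloo_comm_group(), async_op=True)
# Once recv_handle.is_completed(), rank 0 saves a checkpoint, reloads the block index,
# toggles the flag (INDEX_READY = 1 - INDEX_READY), broadcasts it back from rank 0,
# and the receive is re-posted for the next index rebuild.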
megatron/utils.py (view file @ 32bb4edc)
...
@@ -25,7 +25,6 @@ from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import save_checkpoint
from megatron.data.samplers import DistributedBatchSampler
from megatron.mpu.initialize import get_data_parallel_group
from megatron.fp16 import FP16_Optimizer
...
@@ -33,13 +32,8 @@ def reduce_losses(losses):
    """Reduce a tensor of losses across all GPUs."""
    reduced_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(reduced_losses, group=get_data_parallel_group())
    args = get_args()
    if args.max_training_rank is not None:
        num_trainers = args.max_training_rank
    else:
        num_trainers = torch.distributed.get_world_size()
    reduced_losses = reduced_losses / num_trainers
    torch.distributed.all_reduce(reduced_losses)
    reduced_losses = reduced_losses / torch.distributed.get_world_size()

    return reduced_losses
...
@@ -84,7 +78,7 @@ def check_adlr_autoresume_termination(iteration, model,
    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistnecy.
    torch.distributed.barrier(get_data_parallel_group())
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
...
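A toy illustration (hypothetical values, no distributed setup) of the normalization difference in reduce_losses above: the removed REALM path averaged the summed losses over the trainer ranks only, while the restored path averages over the full world size.

import torch

summed_losses = torch.tensor([8.0, 16.0])   # pretend all_reduce has already summed over ranks
world_size, max_training_rank = 8, 6        # hypothetical rank counts
print(summed_losses / max_training_rank)    # per-trainer average (removed REALM path)
print(summed_losses / world_size)           # per-rank average (restored ICT path)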
pretrain_bert_ict.py (view file @ 32bb4edc)
...
@@ -36,6 +36,7 @@ def model_provider(only_query_model=False, only_block_model=False):
    args = get_args()

    print_rank_0('building BERT models ...')
    # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
    model = ICTBertModel(
        ict_head_size=128,
        num_tokentypes=2,
...
@@ -93,19 +94,16 @@ def forward_step(data_iterator, model):
    all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
    all_block_logits = all_query_logits.clone().cuda()

    # record this processes' data and then merge with other processes below
    all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
    all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
    # print(all_query_logits[:, :5], flush=True)
    # print(all_block_logits[:, :5], flush=True)

    dist.all_reduce(all_query_logits)
    dist.all_reduce(all_block_logits)
    # print(all_query_logits[:, :5], flush=True)
    # print(all_block_logits[:, :5], flush=True)

    # scores are inner products between query and block embeddings
    retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
    softmaxed = F.softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)

    def topk_acc(k):
...
@@ -113,11 +111,6 @@ def forward_step(data_iterator, model):
    top_accs = [topk_acc(k) for k in [1, 8, 20, 100]]

    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores,
                                                 torch.arange(global_batch_size).long().cuda())
    # correct_probs = torch.gather(softmaxed, 1, torch.arange(global_batch_size).long().cuda().reshape(-1, 1))
    # assert correct_probs[3] == softmaxed[3, 3]
    # retrieval_loss = -torch.sum(torch.log(correct_probs)) / global_batch_size

    reduced_losses = reduce_losses([retrieval_loss, *top_accs])
    stats_dict = {'retrieval loss': reduced_losses[0],
...
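A self-contained sketch (random tensors, CPU, hypothetical sizes, not part of the diff) of the in-batch retrieval objective computed in forward_step above: each query's own block is the positive, so the targets are the diagonal of the score matrix.

import torch

global_batch_size, embed_size = 4, 8
all_query_logits = torch.randn(global_batch_size, embed_size)
all_block_logits = torch.randn(global_batch_size, embed_size)

# scores are inner products between query and block embeddings
retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
# the i-th query should rank the i-th block highest
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores,
                                             torch.arange(global_batch_size).long())
print(retrieval_loss.item())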
pretrain_realm.py (deleted, 100644 → 0; view file @ 674814a5)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""

import torch
import torch.nn.functional as F

from indexer import load_ict_checkpoint, get_ict_dataset
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import REALMBertModel, REALMRetriever
from megatron.training import pretrain
from megatron.utils import reduce_losses, report_memory
from megatron import mpu
from indexer import initialize_and_run_async_megatron

num_batches = 0


def model_provider():
    """Build the model."""
    args = get_args()
    print_rank_0('building REALM models ...')

    try:
        ict_model = load_ict_checkpoint(from_realm_chkpt=True)
    except:
        ict_model = load_ict_checkpoint(from_realm_chkpt=False)
    ict_dataset = get_ict_dataset(use_titles=False)

    all_block_data = BlockData.load_from_file(args.block_data_path)
    # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
    hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=args.faiss_use_gpu)
    hashed_index.add_block_embed_data(all_block_data)

    # top_k + 1 because we may need to exclude trivial candidate
    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)
    model = REALMBertModel(retriever)

    return model


def get_batch(data_iterator):
    # Items and their type.
    keys = ['tokens', 'labels', 'loss_mask', 'pad_mask', 'query_block_indices']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['tokens'].long()
    labels = data_b['labels'].long()
    loss_mask = data_b['loss_mask'].long()
    pad_mask = data_b['pad_mask'].long()
    query_block_indices = data_b['query_block_indices'].long()

    return tokens, labels, loss_mask, pad_mask, query_block_indices


def get_qa_batch(data_iterator):
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = next(data_iterator)
    return question_tokens, question_attention_mask, answer_tokens, answer_token_lengths


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, pad_mask, query_block_indices = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
    with torch.no_grad():
        max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility = mpu.checkpoint(
            get_retrieval_utility, lm_logits, block_probs, labels, loss_mask)

    # P(y|x) = sum_z(P(y|z, x) * P(z|x))
    null_block_probs = torch.mean(block_probs[:, block_probs.shape[1] - 1])
    block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
    lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]

    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(), labels.contiguous())
    lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility,
                                  avg_retrieval_utility, null_block_probs])
    # torch.cuda.synchronize()

    return lm_loss, {'lm_loss': reduced_loss[0],
                     'max_ru': reduced_loss[1],
                     'top_ru': reduced_loss[2],
                     'avg_ru': reduced_loss[3],
                     'null_prob': reduced_loss[4]}


def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
    """log P(y | z, x) - log P(y | null, x)"""
    # [batch x seq_len x vocab_size]
    lm_logits = lm_logits[:, :, :labels.shape[1], :]
    #non_null_block_probs = block_probs[:, :-1]
    #non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
    # non_null_block_probs = non_null_block_probsexpand_as(lm_logits[:, :-1, :, :])

    null_block_lm_logits = lm_logits[:, -1, :, :]
    null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
                                                        labels.contiguous())
    null_block_loss = torch.sum(null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    retrieved_block_losses = []
    for block_num in range(lm_logits.shape[1] - 1):
        retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
        retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(
            retrieved_block_lm_logits.contiguous().float(), labels.contiguous())
        #retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
        retrieved_block_loss = torch.sum(
            retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
        retrieved_block_losses.append(retrieved_block_loss)
    avg_retrieved_block_loss = torch.sum(
        torch.cuda.FloatTensor(retrieved_block_losses)) / (lm_logits.shape[1] - 1)

    max_retrieval_utility = null_block_loss - min(retrieved_block_losses)
    top_retrieval_utility = null_block_loss - retrieved_block_losses[0]
    avg_retrieval_utility = null_block_loss - avg_retrieved_block_loss

    return max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility


def qa_forward_step(data_iterator, model):
    timers = get_timers()

    # this dataset interface needs to be implemented
    timers('batch generator').start()
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = get_qa_batch(data_iterator)
    timers('batch generator').stop()

    batch_span_logits, batch_loss_masks, block_probs = model(
        question_tokens, question_attention_mask, answer_tokens, answer_token_lengths)

    # [batch_size x k x num_spans]
    block_probs = block_probs.unsqueeze(2).expand_as(batch_span_logits)
    batch_span_probs = F.softmax(batch_span_logits, dim=2)
    reduced_block_span_probs = torch.sum(batch_span_probs * block_probs, dim=1)

    qa_span_loss_ = -torch.log(reduced_block_span_probs)
    qa_span_loss = torch.sum(qa_span_loss_.view(-1) * batch_loss_masks)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='realm')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
             initializer_func=initialize_and_run_async_megatron)
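Finally, a toy sketch (random tensors, hypothetical sizes, not part of the deleted file) of the marginalization used in forward_step above, P(y|x) = sum_z(P(y|z, x) * P(z|x)), implemented by expanding block_probs over the per-block LM logits:

import torch

batch_size, top_k, seq_length, vocab_size = 2, 3, 5, 11
lm_logits = torch.randn(batch_size, top_k, seq_length, vocab_size)
block_probs = torch.softmax(torch.randn(batch_size, top_k), dim=1)

# expand P(z | x) over the per-block LM logits and sum out the block dimension
expanded = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
marginalized = torch.sum(lm_logits * expanded, dim=1)
print(marginalized.shape)  # torch.Size([2, 5, 11])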