OpenDAS / Megatron-LM / Commits / 32bb4edc

Commit 32bb4edc authored Jun 05, 2020 by Neel Kant
Prune changes to only be related to ICT
Parent: 674814a5
Showing 7 changed files with 22 additions and 637 deletions (+22 -637):

    megatron/model/realm_model.py    +6    -296
    megatron/mpu/data.py             +1    -1
    megatron/mpu/initialize.py       +0    -57
    megatron/training.py             +9    -68
    megatron/utils.py                +3    -9
    pretrain_bert_ict.py             +3    -10
    pretrain_realm.py                +0    -196
megatron/model/realm_model.py  (view file @ 32bb4edc)
+import os
 import numpy as np
 import torch
 import torch.nn.functional as F

 from megatron import get_args
-from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import load_checkpoint, get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
 from megatron.model import BertModel
 from megatron.model.utils import get_linear_layer, init_method_normal
...
@@ -12,294 +13,6 @@ from megatron.utils import report_memory
 from megatron import mpu

-class REALMAnswerSpanModel(MegatronModule):
-    def __init__(self, realm_model, mlp_hidden_size=64):
-        super(REALMAnswerSpanModel, self).__init__()
-        self.realm_model = realm_model
-        self.mlp_hidden_size = mlp_hidden_size
-
-        args = get_args()
-        init_method = init_method_normal(args.init_method_std)
-        self.fc1 = get_linear_layer(2 * args.hidden_size, self.mlp_hidden_size, init_method)
-        self._fc1_key = 'fc1'
-        self.fc2 = get_linear_layer(self.mlp_hidden_size, 1, init_method)
-        self._fc2_key = 'fc2'
-
-        max_length = 10
-        self.start_ends = []
-        for length in range(max_length):
-            self.start_ends.extend([(i, i + length) for i in range(288 - length)])
-
-    def forward(self, question_tokens, question_attention_mask, answer_tokens, answer_token_lengths):
-        lm_logits, block_probs, topk_block_tokens = self.realm_model(
-            question_tokens, question_attention_mask,
-            query_block_indices=None, return_topk_block_tokens=True)
-
-        batch_span_reps, batch_loss_masks = [], []
-        # go through batch one-by-one
-        for i in range(len(answer_token_lengths)):
-            answer_length = answer_token_lengths[i]
-            answer_span_tokens = answer_tokens[i][:answer_length]
-            span_reps, loss_masks = [], []
-
-            # go through the top k for the batch item
-            for logits, block_tokens in zip(lm_logits[i], topk_block_tokens[i]):
-                block_logits = logits[len(logits) / 2:]
-                span_starts = range(len(block_tokens) - (answer_length - 1))
-
-                # record the start, end indices of spans which match the answer
-                matching_indices = set([
-                    (idx, idx + answer_length - 1) for idx in span_starts
-                    if np.array_equal(block_tokens[idx: idx + answer_length], answer_span_tokens)
-                ])
-
-                # create a mask for computing the loss on P(y | z, x)
-                # [num_spans]
-                loss_masks.append(torch.LongTensor(
-                    [int(idx_pair in matching_indices) for idx_pair in self.start_ends]))
-
-                # get all of the candidate spans that need to be fed to MLP
-                # [num_spans x 2 * embed_size]
-                span_reps.append([torch.cat((block_logits[s], block_logits[e]))
-                                  for (s, e) in self.start_ends])
-
-            # data for all k blocks for a single batch item
-            # [k x num_spans]
-            batch_loss_masks.append(torch.stack(loss_masks))
-            # [k x num_spans x 2 * embed_size]
-            batch_span_reps.append(torch.stack(span_reps))
-
-        # data for all batch items
-        # [batch_size x k x num_spans]
-        batch_loss_masks = torch.stack(batch_loss_masks)
-        batch_span_reps = torch.stack(batch_span_reps)
-
-        # [batch_size x k x num_spans]
-        batch_span_logits = self.fc2(self.fc1(batch_span_reps)).squeeze()
-
-        return batch_span_logits, batch_loss_masks, block_probs
-
-        # block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
-        # lm_logits = torch.sum(lm_logits * block_probs, dim=1)
-
-
-class REALMBertModel(MegatronModule):
-    def __init__(self, retriever):
-        super(REALMBertModel, self).__init__()
-        bert_args = dict(
-            num_tokentypes=2,
-            add_binary_head=False,
-            parallel_output=True
-        )
-        self.lm_model = BertModel(**bert_args)
-        load_checkpoint(self.lm_model, optimizer=None, lr_scheduler=None)
-        self._lm_key = 'realm_lm'
-
-        self.retriever = retriever
-        self.top_k = self.retriever.top_k
-        self._retriever_key = 'retriever'
-
-    def forward(self, tokens, attention_mask, query_block_indices, return_topk_block_tokens=False):
-        # print("\nNEW FORWARD", '-' * 100, flush=True)
-        dset = self.retriever.ict_dataset
-        det_tokens = detach(tokens)[0].tolist()
-        det_attention = detach(attention_mask)[0].tolist()
-        # print("\nTokens: ", det_tokens, '\n', flush=True)
-        # print("\nAttention: ", det_attention, '\n', flush=True)
-        # print("pad id: ", dset.pad_id, flush=True)
-
-        assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
-        if 0 in det_attention:
-            idx_padid = det_tokens.index(dset.pad_id)
-            idx_attn = det_attention.index(0)
-            assert idx_padid == idx_attn, (idx_padid, idx_attn)
-        # text = dset.decode_tokens(det_tokens)
-        # print(text, flush=True)
-        # print("Token shape: ", tokens.shape, flush=True)
-
-        # [batch_size x k x seq_length]
-        topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
-            tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
-        # print("Top k block shape: ", topk_block_tokens.shape, flush=True)
-
-        batch_size = tokens.shape[0]
-        # create a copy in case it needs to be returned
-        ret_topk_block_tokens = np.array(topk_block_tokens)
-
-        seq_length = topk_block_tokens.shape[2]
-        long_tensor = torch.cuda.LongTensor
-        topk_block_tokens = long_tensor(topk_block_tokens).reshape(-1, seq_length)
-        topk_block_attention_mask = long_tensor(topk_block_attention_mask).reshape(-1, seq_length)
-        # print('Block token shape: ', topk_block_tokens.shape, flush=True)
-
-        # [batch_size x k x embed_size]
-        true_model = self.retriever.ict_model.module.module
-        fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask)
-        fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
-        # print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)
-
-        # [batch_size x embed_size x 1]
-        query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2)
-        # print('Query logits shape: ', query_logits.shape, flush=True)
-
-        # [batch_size x k]
-        fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
-        # print('Block score shape: ', fresh_block_scores.shape, flush=True)
-        block_probs = F.softmax(fresh_block_scores, dim=1)
-
-        # [batch_size * k x seq_length]
-        tokens = torch.stack([tokens.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
-        #assert all(tokens[i] == tokens[0] for i in range(self.top_k))
-        #assert all(tokens[i] == tokens[self.top_k] for i in range(self.top_k, 2 * self.top_k))
-        #assert not any(tokens[i] == tokens[0] for i in range(self.top_k, batch_size * self.top_k))
-        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
-
-        # [batch_size * k x 2 * seq_length]
-        lm_input_batch_shape = (batch_size * self.top_k, 2 * seq_length)
-        all_tokens = torch.zeros(lm_input_batch_shape).long().cuda()
-        all_attention_mask = all_tokens.clone()
-        all_token_types = all_tokens.clone()
-        #all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
-        #all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
-        #all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
-
-        query_lengths = torch.sum(attention_mask, axis=1)
-        # all blocks (including null ones) will have two SEP tokens
-        block_sep_indices = (topk_block_tokens == dset.sep_id).nonzero().reshape(batch_size * self.top_k, 2, 2)
-        # block body starts after the first SEP
-        block_starts = block_sep_indices[:, 0, 1] + 1
-        # block body ends after the second SEP
-        block_ends = block_sep_indices[:, 1, 1] + 1
-        # block_lengths = torch.sum(topk_block_attention_mask, axis=1)
-
-        for row_num in range(all_tokens.shape[0]):
-            q_len = query_lengths[row_num]
-            b_start = block_starts[row_num]
-            b_end = block_ends[row_num]
-
-            # new tokens = CLS + query + SEP + block + SEP
-            new_tokens_length = q_len + b_end - b_start
-
-            # splice query and block tokens accordingly
-            all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
-            all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
-            all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
-            # print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
-
-            all_attention_mask[row_num, :new_tokens_length] = 1
-            all_attention_mask[row_num, new_tokens_length:] = 0
-
-        # [batch_size x k x 2 * seq_length x vocab_size]
-        lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
-        lm_logits = lm_logits.reshape(batch_size, self.top_k, 2 * seq_length, -1)
-
-        if return_topk_block_tokens:
-            return lm_logits, block_probs, ret_topk_block_tokens
-
-        return lm_logits, block_probs
-
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
-        """For easy load when model is combined with other heads,
-        add an extra key."""
-        state_dict_ = {}
-        state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(
-            destination, prefix, keep_vars)
-        state_dict_[self._retriever_key] = self.retriever.state_dict_for_save_checkpoint(
-            destination, prefix, keep_vars)
-        return state_dict_
-
-    def load_state_dict(self, state_dict, strict=True):
-        """Load the state dicts of each of the models"""
-        self.lm_model.load_state_dict(state_dict[self._lm_key], strict)
-        self.retriever.load_state_dict(state_dict[self._retriever_key], strict)
-
-
-class REALMRetriever(MegatronModule):
-    """Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
-    def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
-        super(REALMRetriever, self).__init__()
-        self.ict_model = ict_model
-        self.ict_dataset = ict_dataset
-        self.block_data = block_data
-        self.hashed_index = hashed_index
-        self.top_k = top_k
-        self._ict_key = 'ict_model'
-
-    def reload_index(self):
-        args = get_args()
-        self.block_data = BlockData.load_from_file(args.block_data_path)
-        print("resetting index", flush=True)
-        self.hashed_index.reset_index()
-        self.hashed_index.add_block_embed_data(self.block_data)
-
-    def prep_query_text_for_retrieval(self, query_text):
-        padless_max_len = self.ict_dataset.max_seq_length - 2
-        query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
-        query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)
-
-        query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
-        query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))
-
-        return query_tokens, query_pad_mask
-
-    def retrieve_evidence_blocks_text(self, query_text):
-        """Get the top k evidence blocks for query_text in text form"""
-        print("-" * 100)
-        print("Query: ", query_text)
-        query_tokens, query_pad_mask = self.prep_query_text_for_retrieval(query_text)
-        topk_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
-        for i, block in enumerate(topk_block_tokens[0]):
-            block_text = self.ict_dataset.decode_tokens(block)
-            print('\n> Block {}: {}'.format(i, block_text))
-
-    def retrieve_evidence_blocks(self, query_tokens, query_pad_mask, query_block_indices=None, include_null_doc=False):
-        """Embed blocks to be used in a forward pass"""
-        with torch.no_grad():
-            if hasattr(self.ict_model, 'module'):
-                true_model = self.ict_model.module
-                if hasattr(true_model, 'module'):
-                    true_model = true_model.module
-            else:
-                true_model = self.ict_model
-            # print("true model: ", true_model, flush=True)
-
-            query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)
-            _, block_indices = self.hashed_index.search_mips_index(
-                query_embeds, top_k=self.top_k, reconstruct=False)
-            all_topk_tokens, all_topk_pad_masks = [], []
-
-            # this will result in no candidate exclusion
-            if query_block_indices is None:
-                query_block_indices = [-1] * len(block_indices)
-
-            top_k_offset = int(include_null_doc)
-            for query_idx, indices in enumerate(block_indices):
-                # [k x meta_dim]
-                # exclude trivial candidate if it appears, else just trim the weakest in the top-k
-                topk_metas = [self.block_data.meta_data[idx] for idx in indices
-                              if idx != query_block_indices[query_idx]]
-                topk_block_data = [self.ict_dataset.get_block(*block_meta)
-                                   for block_meta in topk_metas[:self.top_k - top_k_offset]]
-                if include_null_doc:
-                    topk_block_data.append(self.ict_dataset.get_null_block())
-                topk_tokens, topk_pad_masks = zip(*topk_block_data)
-
-                all_topk_tokens.append(np.array(topk_tokens))
-                all_topk_pad_masks.append(np.array(topk_pad_masks))
-
-            # [batch_size x k x seq_length]
-            return np.array(all_topk_tokens), np.array(all_topk_pad_masks)
-
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
-        """For easy load when model is combined with other heads,
-        add an extra key."""
-        state_dict_ = {}
-        state_dict_[self._ict_key] = self.ict_model.state_dict_for_save_checkpoint(
-            destination, prefix, keep_vars)
-        return state_dict_
-
-    def load_state_dict(self, state_dict, strict=True):
-        """Load the state dicts of each of the models"""
-        self.ict_model.load_state_dict(state_dict[self._ict_key], strict)
-
-
 class ICTBertModel(MegatronModule):
     """Bert-based module for Inverse Cloze task."""

     def __init__(self,
...
@@ -341,10 +54,6 @@ class ICTBertModel(MegatronModule):
         block_logits = self.embed_block(block_tokens, block_attention_mask)
         return query_logits, block_logits

-        # [batch x embed] * [embed x batch]
-        # retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
-        # return retrieval_scores

     def embed_query(self, query_tokens, query_attention_mask):
         """Embed a batch of tokens using the query model"""
         if self.use_query_model:
...
@@ -391,10 +100,8 @@ class ICTBertModel(MegatronModule):
             state_dict[self._block_key], strict=strict)

     def init_state_dict_from_bert(self):
+        """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining"""
         args = get_args()
-        import os
-        from megatron import mpu
-        from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
         tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
         if not os.path.isfile(tracker_filename):
             raise FileNotFoundError("Could not find BERT load for ICT")
...
@@ -412,8 +119,11 @@ class ICTBertModel(MegatronModule):
         except BaseException:
             raise ValueError("Could not load checkpoint")

+        # load the LM state dict into each model
         model_dict = state_dict['model']['language_model']
         self.query_model.language_model.load_state_dict(model_dict)
         self.block_model.language_model.load_state_dict(model_dict)

+        # give each model the same ict_head to begin with as well
         query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
         self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
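
For context on the code removed above: REALMBertModel.forward scores each retrieved block against the query by taking the inner product of the ICT query embedding with each of the top-k block embeddings and softmaxing the results into P(z|x), before running the reader BERT over every query+block concatenation. Below is a minimal, self-contained sketch of just that scoring step, assuming the embeddings are already computed; the function name and toy shapes are illustrative and not part of the repository.

import torch
import torch.nn.functional as F

def block_attention_probs(query_embed, block_embeds):
    """Sketch of the retrieval scoring in the removed REALMBertModel.forward:
    P(z|x) is a softmax over inner products between the query embedding and
    the top-k block embeddings.

    query_embed:  [batch_size, embed_size]
    block_embeds: [batch_size, k, embed_size]
    returns:      [batch_size, k]
    """
    # [batch_size, k, 1] = [batch_size, k, embed] @ [batch_size, embed, 1]
    scores = torch.matmul(block_embeds, query_embed.unsqueeze(2))
    return F.softmax(scores.squeeze(2), dim=1)

# toy usage: batch of 2 queries, top-4 blocks, 8-dim embeddings
probs = block_attention_probs(torch.randn(2, 8), torch.randn(2, 4, 8))
assert probs.shape == (2, 4) and torch.allclose(probs.sum(dim=1), torch.ones(2))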
megatron/mpu/data.py  (view file @ 32bb4edc)

...
@@ -78,7 +78,7 @@ def broadcast_data(keys, data, datatype):
     members of the same model parallel group.

     Arguments:
-        keys: list of keys in the data dictionary to be broadcasted
+        keys: list of keys in the data disctionary to be broadcasted
         data: data dictionary of string keys and cpu tensor values.
         datatype: torch data type of all tensors in data associated
                   with keys.
...
megatron/mpu/initialize.py  (view file @ 32bb4edc)

...
@@ -16,7 +16,6 @@
 """Model and data parallel groups."""

-import datetime
 import torch

 from .utils import ensure_divisibility
...
@@ -27,11 +26,6 @@ _MODEL_PARALLEL_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
-_GLOO_COMM_GROUP = None
-_TRAIN_GROUP = None
-_INDEX_GROUP = None
-_INDEX_READY = None
 # These values enable us to change the mpu sizes on the fly.
 _MPU_WORLD_SIZE = None
 _MPU_RANK = None
...
@@ -102,13 +96,6 @@ def get_model_parallel_group():
     return _MODEL_PARALLEL_GROUP


-def set_model_parallel_group(group):
-    global _MODEL_PARALLEL_GROUP
-    assert _MODEL_PARALLEL_GROUP is None, \
-        'model parallel group has already been initialized'
-    _MODEL_PARALLEL_GROUP = group


 def get_data_parallel_group():
     """Get the data parallel group the caller rank belongs to."""
     assert _DATA_PARALLEL_GROUP is not None, \
...
@@ -116,13 +103,6 @@ def get_data_parallel_group():
     return _DATA_PARALLEL_GROUP


-def set_data_parallel_group(group):
-    global _DATA_PARALLEL_GROUP
-    assert _DATA_PARALLEL_GROUP is None, \
-        'data parallel group has already been initialized'
-    _DATA_PARALLEL_GROUP = group


 def set_model_parallel_world_size(world_size):
     """Set the model parallel size"""
     global _MPU_WORLD_SIZE
...
@@ -175,40 +155,3 @@ def destroy_model_parallel():
     _MODEL_PARALLEL_GROUP = None
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
-
-
-def init_realm_groups(max_training_rank, world_size):
-    global _GLOO_COMM_GROUP
-    _GLOO_COMM_GROUP = torch.distributed.new_group(
-        list(range(world_size)), backend="gloo", timeout=datetime.timedelta(0, 7200))
-
-    global _TRAIN_GROUP
-    _TRAIN_GROUP = torch.distributed.new_group(list(range(max_training_rank)))
-
-    global _INDEX_GROUP
-    _INDEX_GROUP = torch.distributed.new_group(list(range(max_training_rank, world_size)))
-
-    global _INDEX_READY
-    _INDEX_READY = torch.zeros(1)
-
-
-def get_gloo_comm_group():
-    global _GLOO_COMM_GROUP
-    assert _GLOO_COMM_GROUP is not None
-    return _GLOO_COMM_GROUP
-
-
-def get_train_group():
-    global _TRAIN_GROUP
-    assert _TRAIN_GROUP is not None
-    return _TRAIN_GROUP
-
-
-def get_index_group():
-    global _INDEX_GROUP
-    assert _INDEX_GROUP is not None
-    return _INDEX_GROUP
-
-
-def get_index_ready():
-    global _INDEX_READY
-    assert _INDEX_READY is not None
-    return _INDEX_READY
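
The removed init_realm_groups split the world into one Gloo group over all ranks plus separate groups for trainer ranks [0, max_training_rank) and index-builder ranks [max_training_rank, world_size). The sketch below only illustrates that rank split; it is not code from the repository, and the helper name and example numbers are made up.

def describe_rank(rank, max_training_rank, world_size):
    """Illustrative sketch of the rank split implied by the removed
    init_realm_groups: the first max_training_rank ranks train the model,
    the remaining ranks rebuild the block index asynchronously."""
    assert 0 < max_training_rank <= world_size and 0 <= rank < world_size
    if rank < max_training_rank:
        return "trainer", list(range(max_training_rank))
    return "index builder", list(range(max_training_rank, world_size))

# e.g. with 8 ranks and max_training_rank=6: ranks 0-5 train, ranks 6-7 index
print(describe_rank(7, 6, 8))  # ('index builder', [6, 7])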
megatron/training.py  (view file @ 32bb4edc)

...
@@ -18,8 +18,6 @@
 from datetime import datetime
 import math
 import sys
-import time
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from apex.optimizers import FusedAdam as Adam
...
@@ -37,19 +35,14 @@ from megatron.initialize import initialize_megatron
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
-from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group, get_gloo_comm_group
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
 from megatron.utils import report_memory

-INDEX_READY = None
-

 def pretrain(train_valid_test_dataset_provider, model_provider,
-             forward_step_func, extra_args_provider=None, args_defaults={},
-             initializer_func=None):
+             forward_step_func, extra_args_provider=None, args_defaults={}):
     """Main training program.

     This function will run the followings in the order provided:
...
@@ -75,14 +68,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     """
     # Initalize and get arguments, timers, and Tensorboard writer.
-    if initializer_func is None:
-        initialize_megatron(extra_args_provider=extra_args_provider,
-                            args_defaults=args_defaults)
-    else:
-        initializer_func(extra_args_provider=extra_args_provider,
-                         args_defaults=args_defaults)
-    global INDEX_READY
-    INDEX_READY = get_index_ready()
+    initialize_megatron(extra_args_provider=extra_args_provider,
+                        args_defaults=args_defaults)

     args = get_args()
     timers = get_timers()
...
@@ -232,10 +219,8 @@ def setup_model_and_optimizer(model_provider_func):
         args.iteration = 0

     if args.iteration == 0 and isinstance(model.module.module, ICTBertModel):
-        print("Yes, located ICT model", flush=True)
+        print("Initializing ICT from pretrained BERT model", flush=True)
         model.module.module.init_state_dict_from_bert()
-    elif args.iteration == 0:
-        print("Ooops", flush=True)

     return model, optimizer, lr_scheduler
...
@@ -244,15 +229,12 @@ def backward_step(optimizer, model, loss):
     """Backward step."""
     args = get_args()
     timers = get_timers()

-    # torch.cuda.synchronize()
     # Backward pass.
-    # optimizer.zero_grad(set_grads_to_None=True)
+    optimizer.zero_grad(set_grads_to_None=True)
     if args.fp16:
-        optimizer.zero_grad(set_grads_to_None=True)
         optimizer.backward(loss, update_master_grads=False)
     else:
-        optimizer.zero_grad()
         loss.backward()

     # All-reduce if needed.
...
@@ -261,9 +243,11 @@ def backward_step(optimizer, model, loss):
         model.allreduce_params(reduce_after=False,
                                fp32_allreduce=args.fp32_allreduce)
         timers('allreduce').stop()

     # Update master gradients.
     if args.fp16:
         optimizer.update_master_grads()

     # Clipping gradients helps prevent the exploding gradient.
     if args.clip_grad > 0:
         if not args.fp16:
...
@@ -283,12 +267,11 @@ def train_step(forward_step_func, data_iterator,
     loss, loss_reduced = forward_step_func(data_iterator, model)
     timers('forward').stop()

+    # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
     backward_step(optimizer, model, loss)
     timers('backward').stop()
-    # Calculate gradients, reduce across processes, and clip.

     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
...
@@ -383,54 +366,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True

-    global INDEX_READY
-    print('>>> Starting train()', flush=True)
-    # start off by posting a receive call which will be answered.
-    # synchronize for start
-    if args.max_training_rank is not None:
-        torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
-        recv_handle = torch.distributed.broadcast(
-            INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
-        last_reload_iteration = iteration

     while iteration < args.train_iters:
-        if args.max_training_rank is not None and iteration >= last_reload_iteration + 500 \
-                and not recv_handle.is_completed():
-            time.sleep(5)
-            continue
-
-        # this only applies for realm right here
-        if args.max_training_rank is not None and recv_handle.is_completed():
-            # should add check that INDEX_READY == 1 but what else could be happening
-            true_model = model
-            if hasattr(true_model, 'module'):
-                true_model = true_model.module
-            if hasattr(true_model, 'module'):
-                true_model = true_model.module
-
-            print("> Saving model and reloading index", flush=True)
-            if args.rank == 0:
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
-            true_model.retriever.reload_index()
-
-            if args.rank == 0:
-                INDEX_READY = 1 - INDEX_READY
-            torch.cuda.synchronize()
-            # send handle
-            torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
-            torch.distributed.barrier(get_data_parallel_group())
-            recv_handle = torch.distributed.broadcast(
-                INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
-            last_reload_iteration = iteration
-        elif iteration < 20:
-            print("moving right along", flush=True)
-
-        # report_memory("iteration {}".format(iteration))
         loss_dict, skipped_iter = train_step(forward_step_func,
                                              train_data_iterator,
                                              model,
                                              optimizer,
                                              lr_scheduler)
         skipped_iters += skipped_iter
         iteration += 1
...
@@ -463,7 +404,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                        iteration, False)

         if args.exit_interval and iteration % args.exit_interval == 0:
-            torch.distributed.barrier(get_data_parallel_group())
+            torch.distributed.barrier()
             time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
             rank = torch.distributed.get_rank()
             print_rank_0('rank: {} | time: {} | exiting the program at '
...
megatron/utils.py  (view file @ 32bb4edc)

...
@@ -25,7 +25,6 @@ from megatron import mpu
 from megatron import print_rank_0
 from megatron.checkpointing import save_checkpoint
 from megatron.data.samplers import DistributedBatchSampler
-from megatron.mpu.initialize import get_data_parallel_group
 from megatron.fp16 import FP16_Optimizer
...
@@ -33,13 +32,8 @@ def reduce_losses(losses):
     """Reduce a tensor of losses across all GPUs."""
     reduced_losses = torch.cat(
         [loss.clone().detach().view(1) for loss in losses])
-    torch.distributed.all_reduce(reduced_losses, group=get_data_parallel_group())
+    torch.distributed.all_reduce(reduced_losses)
-    args = get_args()
+    reduced_losses = reduced_losses / torch.distributed.get_world_size()
-    if args.max_training_rank is not None:
-        num_trainers = args.max_training_rank
-    else:
-        num_trainers = torch.distributed.get_world_size()
-    reduced_losses = reduced_losses / num_trainers

     return reduced_losses
...
@@ -84,7 +78,7 @@ def check_adlr_autoresume_termination(iteration, model,
     args = get_args()
     autoresume = get_adlr_autoresume()
     # Add barrier to ensure consistnecy.
-    torch.distributed.barrier(get_data_parallel_group())
+    torch.distributed.barrier()
     if autoresume.termination_requested():
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
...
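
The reduce_losses change above reverts the divisor used when averaging all-reduced losses: the removed code divided by args.max_training_rank (only trainer ranks contribute losses under the REALM setup), while the restored code divides by the full world size. A small made-up illustration of the difference, with hypothetical numbers:

# Illustrative numbers only; not values from the repository.
world_size = 8          # all ranks: trainers plus index builders
max_training_rank = 6   # ranks 0-5 train, ranks 6-7 only build the index
summed_loss = 12.0      # hypothetical all-reduced sum of per-rank losses

print(summed_loss / max_training_rank)  # removed behaviour: 2.0
print(summed_loss / world_size)         # behaviour after this commit: 1.5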
pretrain_bert_ict.py  (view file @ 32bb4edc)

...
@@ -36,6 +36,7 @@ def model_provider(only_query_model=False, only_block_model=False):
     args = get_args()
     print_rank_0('building BERT models ...')

+    # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
     model = ICTBertModel(
         ict_head_size=128,
         num_tokentypes=2,
...
@@ -93,19 +94,16 @@ def forward_step(data_iterator, model):
     all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
     all_block_logits = all_query_logits.clone().cuda()

+    # record this processes' data and then merge with other processes below
     all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
     all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits

-    # print(all_query_logits[:, :5], flush=True)
-    # print(all_block_logits[:, :5], flush=True)
     dist.all_reduce(all_query_logits)
     dist.all_reduce(all_block_logits)
-    # print(all_query_logits[:, :5], flush=True)
-    # print(all_block_logits[:, :5], flush=True)

+    # scores are inner products between query and block embeddings
     retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
     softmaxed = F.softmax(retrieval_scores, dim=1)
     sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)

     def topk_acc(k):
...
@@ -113,11 +111,6 @@ def forward_step(data_iterator, model):
     top_accs = [topk_acc(k) for k in [1, 8, 20, 100]]
     retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())

-    # correct_probs = torch.gather(softmaxed, 1, torch.arange(global_batch_size).long().cuda().reshape(-1, 1))
-    # assert correct_probs[3] == softmaxed[3, 3]
-    # retrieval_loss = -torch.sum(torch.log(correct_probs)) / global_batch_size
     reduced_losses = reduce_losses([retrieval_loss, *top_accs])

     stats_dict = {
         'retrieval loss': reduced_losses[0],
...
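
forward_step above gathers query and block embeddings from every data-parallel rank and trains with an in-batch softmax: retrieval_scores is a global_batch_size x global_batch_size matrix of inner products, and the correct block for query i is block i, hence the torch.arange target. A self-contained sketch of that loss, assuming the gathered embeddings are already in hand (the function name is illustrative, not part of the repository):

import torch
import torch.nn.functional as F

def ict_in_batch_loss(query_logits, block_logits):
    """Sketch of the in-batch retrieval loss in forward_step:
    score[i, j] = <query_i, block_j>, and the positive block for query_i
    is block_i, so the cross-entropy target is the diagonal.

    query_logits, block_logits: [global_batch_size, embed_size]
    """
    scores = query_logits.float().matmul(block_logits.t().float())
    targets = torch.arange(scores.shape[0], device=scores.device)
    return F.cross_entropy(scores, targets)

# toy usage on CPU: 16 query/block pairs with 128-dim embeddings
loss = ict_in_batch_loss(torch.randn(16, 128), torch.randn(16, 128))
print(loss.item())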
pretrain_realm.py  deleted 100644 → 0  (view file @ 674814a5)

The entire file is removed by this commit; its former contents follow.
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""

import torch
import torch.nn.functional as F

from indexer import load_ict_checkpoint, get_ict_dataset
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import REALMBertModel, REALMRetriever
from megatron.training import pretrain
from megatron.utils import reduce_losses, report_memory
from megatron import mpu
from indexer import initialize_and_run_async_megatron

num_batches = 0


def model_provider():
    """Build the model."""
    args = get_args()
    print_rank_0('building REALM models ...')

    try:
        ict_model = load_ict_checkpoint(from_realm_chkpt=True)
    except:
        ict_model = load_ict_checkpoint(from_realm_chkpt=False)
    ict_dataset = get_ict_dataset(use_titles=False)
    all_block_data = BlockData.load_from_file(args.block_data_path)

    # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
    hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=args.faiss_use_gpu)
    hashed_index.add_block_embed_data(all_block_data)

    # top_k + 1 because we may need to exclude trivial candidate
    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)

    model = REALMBertModel(retriever)

    return model


def get_batch(data_iterator):
    # Items and their type.
    keys = ['tokens', 'labels', 'loss_mask', 'pad_mask', 'query_block_indices']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['tokens'].long()
    labels = data_b['labels'].long()
    loss_mask = data_b['loss_mask'].long()
    pad_mask = data_b['pad_mask'].long()
    query_block_indices = data_b['query_block_indices'].long()

    return tokens, labels, loss_mask, pad_mask, query_block_indices


def get_qa_batch(data_iterator):
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = next(data_iterator)
    return question_tokens, question_attention_mask, answer_tokens, answer_token_lengths


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, pad_mask, query_block_indices = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
    with torch.no_grad():
        max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility = mpu.checkpoint(
            get_retrieval_utility, lm_logits, block_probs, labels, loss_mask)

    # P(y|x) = sum_z(P(y|z, x) * P(z|x))
    null_block_probs = torch.mean(block_probs[:, block_probs.shape[1] - 1])
    block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
    lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]

    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(), labels.contiguous())
    lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility,
                                  avg_retrieval_utility, null_block_probs])
    # torch.cuda.synchronize()

    return lm_loss, {'lm_loss': reduced_loss[0],
                     'max_ru': reduced_loss[1],
                     'top_ru': reduced_loss[2],
                     'avg_ru': reduced_loss[3],
                     'null_prob': reduced_loss[4]}


def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
    """log P(y | z, x) - log P(y | null, x)"""
    # [batch x seq_len x vocab_size]
    lm_logits = lm_logits[:, :, :labels.shape[1], :]
    #non_null_block_probs = block_probs[:, :-1]
    #non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
    # non_null_block_probs = non_null_block_probsexpand_as(lm_logits[:, :-1, :, :])

    null_block_lm_logits = lm_logits[:, -1, :, :]
    null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
                                                        labels.contiguous())
    null_block_loss = torch.sum(null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    retrieved_block_losses = []
    for block_num in range(lm_logits.shape[1] - 1):
        retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
        retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(
            retrieved_block_lm_logits.contiguous().float(), labels.contiguous())
        #retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
        retrieved_block_loss = torch.sum(
            retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
        retrieved_block_losses.append(retrieved_block_loss)
    avg_retrieved_block_loss = torch.sum(
        torch.cuda.FloatTensor(retrieved_block_losses)) / (lm_logits.shape[1] - 1)

    max_retrieval_utility = null_block_loss - min(retrieved_block_losses)
    top_retrieval_utility = null_block_loss - retrieved_block_losses[0]
    avg_retrieval_utility = null_block_loss - avg_retrieved_block_loss

    return max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility


def qa_forward_step(data_iterator, model):
    timers = get_timers()

    # this dataset interface needs to be implemented
    timers('batch generator').start()
    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = get_qa_batch(data_iterator)
    timers('batch generator').stop()

    batch_span_logits, batch_loss_masks, block_probs = model(
        question_tokens, question_attention_mask, answer_tokens, answer_token_lengths)

    # [batch_size x k x num_spans]
    block_probs = block_probs.unsqueeze(2).expand_as(batch_span_logits)
    batch_span_probs = F.softmax(batch_span_logits, dim=2)
    reduced_block_span_probs = torch.sum(batch_span_probs * block_probs, dim=1)

    qa_span_loss_ = -torch.log(reduced_block_span_probs)
    qa_span_loss = torch.sum(qa_span_loss_.view(-1) * batch_loss_masks)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='realm')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
             initializer_func=initialize_and_run_async_megatron)
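
The deleted forward_step marginalizes the reader over the retrieved blocks, as its inline comment notes: P(y|x) = sum_z P(y|z, x) * P(z|x), implemented by broadcasting block_probs over the per-block reader outputs and summing over the k dimension. Below is a small sketch of that marginalization, mirroring the shapes in the deleted comments; the helper name and toy sizes are illustrative.

import torch

def marginalize_over_blocks(lm_logits, block_probs):
    """Sketch of the REALM marginalization P(y|x) = sum_z P(y|z,x) * P(z|x),
    as in the deleted forward_step, which applies the block probabilities
    directly to the per-block reader outputs and sums over blocks.

    lm_logits:   [batch_size, k, seq_length, vocab_size]
    block_probs: [batch_size, k]
    returns:     [batch_size, seq_length, vocab_size]
    """
    weights = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
    return torch.sum(lm_logits * weights, dim=1)

# toy usage: batch of 2, k=3 blocks, 4 positions, vocabulary of 5
out = marginalize_over_blocks(torch.randn(2, 3, 4, 5),
                              torch.softmax(torch.randn(2, 3), dim=1))
assert out.shape == (2, 4, 5)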