GitLab · OpenDAS / Megatron-LM · Commits

Commit 8e22824e, authored May 24, 2020 by Neel Kant
Fix token alignment, add mpu checkpointing, misc training code
Parent: 8573ab35
Showing 2 changed files, with 58 additions and 35 deletions:

  megatron/model/realm_model.py  +43 -27
  megatron/training.py           +15  -8
megatron/model/realm_model.py (view file @ 8e22824e)

```diff
@@ -8,6 +8,8 @@ from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
 from megatron.model import BertModel
 from megatron.model.utils import get_linear_layer, init_method_normal
 from megatron.module import MegatronModule
+from megatron.utils import report_memory
+from megatron import mpu


 class REALMAnswerSpanModel(MegatronModule):
```
```diff
@@ -105,11 +107,11 @@ class REALMBertModel(MegatronModule):
         # [batch_size x k x embed_size]
         true_model = self.retriever.ict_model.module.module
-        fresh_block_logits = true_model.embed_block(topk_block_tokens, topk_block_attention_mask)
+        fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask)
         fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)

         # [batch_size x embed_size x 1]
-        query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)
+        query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2)

         # [batch_size x k]
         fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
```
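This hunk swaps the direct `embed_block` / `embed_query` calls for `mpu.checkpoint(...)`, Megatron's activation-checkpointing wrapper, which takes the callable first and then its arguments. A minimal sketch of the same recompute-for-memory trade using stock `torch.utils.checkpoint`; the toy `Encoder` module and tensor shapes below are illustrative assumptions, not code from this repo:

```python
# Recompute-to-save-memory sketch: activations inside the checkpointed
# callable are not stored for backward; they are recomputed on demand.
import torch
from torch.utils.checkpoint import checkpoint


class Encoder(torch.nn.Module):
    def __init__(self, hidden=128):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, x, mask):
        # mask stands in for the attention-mask argument in the diff
        return torch.relu(self.proj(x * mask))


encoder = Encoder()
x = torch.randn(4, 128, requires_grad=True)
mask = torch.ones(4, 128)

# Same calling convention as mpu.checkpoint(fn, *args): callable first,
# then its inputs. Backward recomputes the forward pass of `encoder`.
out = checkpoint(encoder, x, mask)
out.sum().backward()
```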
```diff
@@ -124,6 +126,22 @@ class REALMBertModel(MegatronModule):
         all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
         all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()

+        # re-align tokens to be contiguous
+        query_lengths = torch.sum(attention_mask, axis=1)
+        block_lengths = torch.sum(topk_block_attention_mask, axis=1)
+        for row_num in range(all_tokens.shape[0]):
+            qlen = query_lengths[row_num]
+            blen = block_lengths[row_num]
+
+            # disregard the CLS token from the block tokens
+            new_tokens_length = qlen + blen - 1
+            all_tokens[row_num, :qlen] = tokens[row_num, :qlen]
+            all_tokens[row_num, qlen:new_tokens_length] = tokens[row_num, 1:blen]
+            all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
+            all_attention_mask[row_num, :new_tokens_length] = 1
+            all_attention_mask[row_num, new_tokens_length:] = 0
+
         # [batch_size x k x 2 * seq_length x vocab_size]
         lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
         lm_logits = lm_logits.reshape(batch_size, self.top_k, 2 * seq_length, -1)
```
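The new loop is the "fix token alignment" of the commit title: after the naive concatenation, each row is repacked so real query tokens are followed immediately by block tokens (minus the block's leading CLS), with all padding pushed to the end and the attention mask made a single contiguous run of 1s. A toy version of that packing with made-up token ids (this sketch fills the block span from the block row; shapes and ids are illustrative):

```python
# Before: [CLS q q PAD PAD PAD | CLS b b b PAD PAD]
# After:  [CLS q q b b b PAD PAD PAD PAD PAD PAD]
import torch

PAD, CLS = 0, 101                                            # illustrative ids
tokens = torch.tensor([[CLS, 7, 8, PAD, PAD, PAD]])          # qlen = 3
block_tokens = torch.tensor([[CLS, 21, 22, 23, PAD, PAD]])   # blen = 4

qlen = int((tokens != PAD).sum())
blen = int((block_tokens != PAD).sum())

all_tokens = torch.cat((tokens, block_tokens), axis=1)
new_tokens_length = qlen + blen - 1                 # drop the block's CLS
all_tokens[0, qlen:new_tokens_length] = block_tokens[0, 1:blen]
all_tokens[0, new_tokens_length:] = PAD

all_attention_mask = torch.zeros_like(all_tokens)
all_attention_mask[0, :new_tokens_length] = 1

print(all_tokens)
# tensor([[101,  7,  8, 21, 22, 23,  0,  0,  0,  0,  0,  0]])
```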
```diff
@@ -163,11 +181,9 @@ class REALMRetriever(MegatronModule):
     def reload_index(self):
         args = get_args()
+        print("loading from file", flush=True)
         self.block_data = BlockData.load_from_file(args.block_data_path)
         print("resetting index", flush=True)
         self.hashed_index.reset_index()
+        print("adding block data", flush=True)
         self.hashed_index.add_block_embed_data(self.block_data)

     def prep_query_text_for_retrieval(self, query_text):
```
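`reload_index` gains progress prints around a load / reset / re-add cycle: deserialize freshly written block embeddings from disk, clear the MIPS index, and repopulate it. A rough sketch of that cycle against raw FAISS, assuming an inner-product index and illustrative `embeds` / `ids` arrays; in the repo this is wrapped by `BlockData` and `FaissMIPSIndex`:

```python
import numpy as np
import faiss

d = 128
index = faiss.IndexIDMap(faiss.IndexFlatIP(d))      # inner product = MIPS


def reload(index, embeds, ids):
    index.reset()                                   # ~ hashed_index.reset_index()
    index.add_with_ids(embeds, ids)                 # ~ add_block_embed_data(...)


embeds = np.random.rand(1000, d).astype(np.float32)
ids = np.arange(1000).astype(np.int64)
reload(index, embeds, ids)

query = np.random.rand(1, d).astype(np.float32)
scores, block_indices = index.search(query, 5)      # ~ search_mips_index(...)
```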
```diff
@@ -201,29 +217,29 @@ class REALMRetriever(MegatronModule):
         true_model = self.ict_model
         # print("true model: ", true_model, flush=True)
-        query_embeds = detach(self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True))
+        query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)
         _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
         all_topk_tokens, all_topk_pad_masks = [], []

         # this will result in no candidate exclusion
         if query_block_indices is None:
             query_block_indices = [-1] * len(block_indices)

         top_k_offset = int(include_null_doc)
         for query_idx, indices in enumerate(block_indices):
             # [k x meta_dim]
             # exclude trivial candidate if it appears, else just trim the weakest in the top-k
             topk_metas = [self.block_data.meta_data[idx] for idx in indices if idx != query_block_indices[query_idx]]
             topk_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in topk_metas[:self.top_k - top_k_offset]]
             if include_null_doc:
                 topk_block_data.append(self.ict_dataset.get_null_block())
             topk_tokens, topk_pad_masks = zip(*topk_block_data)
             all_topk_tokens.append(np.array(topk_tokens))
             all_topk_pad_masks.append(np.array(topk_pad_masks))

         # [batch_size x k x seq_length]
         return np.array(all_topk_tokens), np.array(all_topk_pad_masks)

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
```
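The only functional change in this hunk drops the `detach(...)` around the query embedding, keeping the retrieval scores on the autograd graph. The surrounding candidate logic is easier to trace with a toy example: when `include_null_doc` is true, one top-k slot is reserved for the null block, and the query's own block (the trivial candidate) is excluded if it appears. Every value below is made up:

```python
top_k = 3
include_null_doc = True

indices = [17, 42, 99, 7]        # MIPS results for one query, best first
query_block_index = 42           # the block the query itself came from

top_k_offset = int(include_null_doc)            # reserve one slot for null doc
topk_metas = [idx for idx in indices if idx != query_block_index]
kept = topk_metas[:top_k - top_k_offset]        # [17, 99]
blocks = kept + (["<null>"] if include_null_doc else [])
print(blocks)                    # [17, 99, '<null>']
```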
megatron/training.py (view file @ 8e22824e)

```diff
@@ -374,12 +374,16 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True
     global INDEX_READY
+    print('>>> Starting train()', flush=True)

     # start off by posting a receive call which will be answered.
+    # synchronize for start
+    torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
     recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
+    last_reload_iteration = iteration
     while iteration < args.train_iters:
         # this only applies for realm right here
-        if args.max_training_rank is not None and recv_handle.is_completed():
+        if args.max_training_rank is not None and recv_handle.is_completed() and iteration >= last_reload_iteration + 500:
             # should add check that INDEX_READY == 1 but what else could be happening
             true_model = model
             if hasattr(true_model, 'module'):
```
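The additions post a synchronizing broadcast at startup, then a non-blocking broadcast whose handle the training loop polls each iteration. A runnable single-process sketch of that poll pattern, assuming a gloo group of world size 1; in the real job the source rank would be `args.max_training_rank` (the index-builder), not 0:

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

INDEX_READY = torch.zeros(1, dtype=torch.long)      # flags travel as tensors

# post the receive-side broadcast once, then poll it inside the loop
recv_handle = dist.broadcast(INDEX_READY, src=0, async_op=True)
for iteration in range(3):
    if recv_handle.is_completed():
        # in train(): save checkpoint, reload the index, flip the flag,
        # broadcast it back, and re-post the receive for the next round
        INDEX_READY = 1 - INDEX_READY
        recv_handle = dist.broadcast(INDEX_READY, src=0, async_op=True)

dist.destroy_process_group()
```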
```diff
@@ -388,20 +392,23 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                 true_model = true_model.module

-            print(">>>>>>> starting to reload index", flush=True)
-            true_model.retriever.reload_index()
+            print("> Saving model and reloading index", flush=True)
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            true_model.retriever.reload_index()
             if args.rank == 0:
                 INDEX_READY = 1 - INDEX_READY
+                print(">>> Switched index ready", flush=True)

             torch.cuda.synchronize()
-            send_handle = torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
+            # send handle
+            torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
             torch.distributed.barrier(get_data_parallel_group())
             recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
-        elif iteration < 100:
+            last_reload_iteration = iteration
+        elif iteration < 20:
             print("moving right along", flush=True)
+            # report_memory("iteration {}".format(iteration))

         loss_dict, skipped_iter = train_step(forward_step_func,
                                              train_data_iterator,
                                              model,
```