OpenDAS / Megatron-LM

Commit 2fd4ea6c
Authored May 25, 2020 by Neel Kant
Parent: 8e22824e

    Corrected realm example building, misc improvements for async concurrency

Showing 6 changed files with 73 additions and 21 deletions
indexer.py                        +2   -1
megatron/data/realm_dataset.py    +3   -2
megatron/model/realm_model.py     +59  -14
megatron/mpu/initialize.py        +0   -1
megatron/training.py              +8   -2
pretrain_realm.py                 +1   -1
indexer.py

@@ -95,6 +95,7 @@ def setup_realm_groups_and_vars():
 class IndexBuilder(object):
     def __init__(self):
         args = get_args()
         self.debug = args.debug
         self.rank = args.rank
         self.model = None
         self.dataloader = None

@@ -287,6 +288,6 @@ def get_one_epoch_dataloader(dataset, batch_size=None):
 if __name__ == "__main__":
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
-    index_builder = BasicIndexBuilder()
+    index_builder = IndexBuilder()
     index_builder.build_and_save_index()
megatron/data/realm_dataset.py

@@ -5,7 +5,7 @@ import numpy as np
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
-from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping
+from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping, join_str_list


 class REALMDataset(Dataset):

@@ -136,7 +136,8 @@ class ICTDataset(Dataset):
     def decode_tokens(self, token_ids):
         tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
-        return ' '.join(token for token in tokens if token != '[PAD]')
+        non_pads = [t for t in tokens if t != '[PAD]']
+        return join_str_list(non_pads)

     def get_block(self, start_idx, end_idx, doc_idx):
         """Get the IDs for an evidence block plus the title of the corresponding document"""
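The rewritten decode_tokens filters out padding and hands the remaining WordPiece tokens to join_str_list, a helper imported above from megatron.data.realm_dataset_utils whose body is not part of this diff. As a rough, hypothetical sketch of what such a joiner could look like, assuming its job is to merge '##' continuation pieces back into whole words:

def join_str_list(str_list):
    """Hypothetical sketch only: glue WordPiece tokens into a readable string.

    Pieces that start with '##' continue the previous word, so they are appended
    without a space; everything else is separated by a single space.
    """
    out = ""
    for piece in str_list:
        if piece.startswith("##"):
            out += piece[2:]          # continuation piece: strip '##', no space
        else:
            out += " " + piece        # start of a new word
    return out.strip()

# Example: ['the', 'real', '##m', 'retriever'] -> 'the realm retriever'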
megatron/model/realm_model.py

@@ -94,51 +94,96 @@ class REALMBertModel(MegatronModule):
         self._retriever_key = 'retriever'

     def forward(self, tokens, attention_mask, query_block_indices, return_topk_block_tokens=False):
         # print("\nNEW FORWARD", '-' * 100, flush=True)
         dset = self.retriever.ict_dataset
         det_tokens = detach(tokens)[0].tolist()
         det_attention = detach(attention_mask)[0].tolist()
         # print("\nTokens: ", det_tokens, '\n', flush=True)
         # print("\nAttention: ", det_attention, '\n', flush=True)
         # print("pad id: ", dset.pad_id, flush=True)
         assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
         if 0 in det_attention:
             idx_padid = det_tokens.index(dset.pad_id)
             idx_attn = det_attention.index(0)
             assert idx_padid == idx_attn, (idx_padid, idx_attn)
         # text = dset.decode_tokens(det_tokens)
         # print(text, flush=True)

         # print("Token shape: ", tokens.shape, flush=True)
         # [batch_size x k x seq_length]
         topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
             tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
         # print("Top k block shape: ", topk_block_tokens.shape, flush=True)
         batch_size = tokens.shape[0]
         # create a copy in case it needs to be returned
         ret_topk_block_tokens = np.array(topk_block_tokens)
         seq_length = topk_block_tokens.shape[2]
-        topk_block_tokens = torch.cuda.LongTensor(topk_block_tokens).reshape(-1, seq_length)
-        topk_block_attention_mask = torch.cuda.LongTensor(topk_block_attention_mask).reshape(-1, seq_length)
+        long_tensor = torch.cuda.LongTensor
+        topk_block_tokens = long_tensor(topk_block_tokens).reshape(-1, seq_length)
+        topk_block_attention_mask = long_tensor(topk_block_attention_mask).reshape(-1, seq_length)
         # print('Block token shape: ', topk_block_tokens.shape, flush=True)

         # [batch_size x k x embed_size]
         true_model = self.retriever.ict_model.module.module
         fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask)
         fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
         # print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)

         # [batch_size x embed_size x 1]
         query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2)
         # print('Query logits shape: ', query_logits.shape, flush=True)

         # [batch_size x k]
         fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
         # print('Block score shape: ', fresh_block_scores.shape, flush=True)
         block_probs = F.softmax(fresh_block_scores, dim=1)

         # [batch_size * k x seq_length]
         tokens = torch.stack([tokens.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
         #assert all(tokens[i] == tokens[0] for i in range(self.top_k))
         #assert all(tokens[i] == tokens[self.top_k] for i in range(self.top_k, 2 * self.top_k))
         #assert not any(tokens[i] == tokens[0] for i in range(self.top_k, batch_size * self.top_k))
         attention_mask = torch.stack([attention_mask.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)

         # [batch_size * k x 2 * seq_length]
-        all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
-        all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
-        all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
+        lm_input_batch_shape = (batch_size * self.top_k, 2 * seq_length)
+        all_tokens = torch.zeros(lm_input_batch_shape).long().cuda()
+        all_attention_mask = all_tokens.clone()
+        all_token_types = all_tokens.clone()
+        #all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
+        #all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
+        #all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()

         # re-align tokens to be contiguous
         query_lengths = torch.sum(attention_mask, axis=1)
-        block_lengths = torch.sum(topk_block_attention_mask, axis=1)
-        for row_num in range(all_tokens.shape[0]):
-            qlen = query_lengths[row_num]
-            blen = block_lengths[row_num]
-            # disregard the CLS token from the block tokens
-            new_tokens_length = qlen + blen - 1
+        # all blocks (including null ones) will have two SEP tokens
+        block_sep_indices = (topk_block_tokens == dset.sep_id).nonzero().reshape(batch_size * self.top_k, 2, 2)
+        # block body starts after the first SEP
+        block_starts = block_sep_indices[:, 0, 1] + 1
+        # block body ends after the second SEP
+        block_ends = block_sep_indices[:, 1, 1] + 1

-            all_tokens[row_num, :qlen] = tokens[row_num, :qlen]
-            all_tokens[row_num, qlen:new_tokens_length] = tokens[row_num, 1:blen]
+        # block_lengths = torch.sum(topk_block_attention_mask, axis=1)
+        for row_num in range(all_tokens.shape[0]):
+            q_len = query_lengths[row_num]
+            b_start = block_starts[row_num]
+            b_end = block_ends[row_num]
+            # new tokens = CLS + query + SEP + block + SEP
+            new_tokens_length = q_len + b_end - b_start

+            # splice query and block tokens accordingly
+            all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
+            all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
+            all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
+            # print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
             all_attention_mask[row_num, :new_tokens_length] = 1
             all_attention_mask[row_num, new_tokens_length:] = 0
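The substance of the fix is the splicing loop: the old code measured each block with its attention mask and then copied from tokens (the query) again, while the new code locates the block body between its two SEP tokens and copies that span after the query. A self-contained sketch of the same splicing on plain Python lists; the token IDs below are made up for illustration and are not the tokenizer's real IDs:

PAD_ID, CLS_ID, SEP_ID = 0, 101, 102     # illustrative IDs only

def splice_query_and_block(query_tokens, block_tokens, out_length):
    """Build one LM input row: query (CLS ... SEP) + block body (... SEP), right-padded."""
    q_len = sum(1 for t in query_tokens if t != PAD_ID)               # query length from its padding
    sep_positions = [i for i, t in enumerate(block_tokens) if t == SEP_ID]
    b_start = sep_positions[0] + 1       # block body starts after the first SEP
    b_end = sep_positions[1] + 1         # and runs through the second SEP
    row = query_tokens[:q_len] + block_tokens[b_start:b_end]
    row += [PAD_ID] * (out_length - len(row))
    mask = [1] * (q_len + b_end - b_start) + [0] * (out_length - (q_len + b_end - b_start))
    return row, mask

query = [CLS_ID, 7, 8, 9, SEP_ID, PAD_ID, PAD_ID, PAD_ID]            # padded query
block = [CLS_ID, 21, 22, SEP_ID, 31, 32, 33, SEP_ID]                 # title SEP body SEP
tokens_row, attention_row = splice_query_and_block(query, block, out_length=16)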
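Earlier in the same forward pass, the retrieval scores are inner products between the query embedding and each of the k freshly re-embedded blocks, normalized with a softmax. A shape-only sketch of that scoring step, with random tensors standing in for the outputs of embed_query and embed_block:

import torch
import torch.nn.functional as F

batch_size, top_k, embed_size = 4, 5, 128

# Stand-ins for the checkpointed embed_block / embed_query outputs.
fresh_block_logits = torch.randn(batch_size, top_k, embed_size)      # [batch_size x k x embed_size]
query_logits = torch.randn(batch_size, embed_size).unsqueeze(2)      # [batch_size x embed_size x 1]

# One score per (query, block) pair, then a distribution over the k blocks.
fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze(2)   # [batch_size x k]
block_probs = F.softmax(fresh_block_scores, dim=1)

assert block_probs.shape == (batch_size, top_k)
assert torch.allclose(block_probs.sum(dim=1), torch.ones(batch_size))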
megatron/mpu/initialize.py

@@ -120,7 +120,6 @@ def set_data_parallel_group(group):
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, \
         'data parallel group has already been initialized'
-    print(">>> setting data parallel group: ", group, flush=True)
     _DATA_PARALLEL_GROUP = group
megatron/training.py

@@ -18,6 +18,7 @@
 from datetime import datetime
 import math
 import sys
+import time
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

@@ -381,8 +382,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank,
                                               group=get_gloo_comm_group(), async_op=True)
     last_reload_iteration = iteration

     while iteration < args.train_iters:
+        if iteration >= last_reload_iteration + 500 and not recv_handle.is_completed():
+            time.sleep(5)
+            continue
+
         # this only applies for realm right here
-        if args.max_training_rank is not None and recv_handle.is_completed() and iteration >= last_reload_iteration + 500:
+        if args.max_training_rank is not None and recv_handle.is_completed():
            # should add check that INDEX_READY == 1 but what else could be happening
            true_model = model

@@ -393,7 +398,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             print("> Saving model and reloading index", flush=True)
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            if args.rank == 0:
+                save_checkpoint(iteration, model, optimizer, lr_scheduler)
             true_model.retriever.reload_index()

             if args.rank == 0:
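The training-loop change makes the index-reload handshake blocking once it is overdue: after 500 iterations without a reload, the trainer polls the asynchronous broadcast handle and sleeps until the index builder signals readiness, and only rank 0 writes the checkpoint. A minimal sketch of that polling pattern with a stand-in for the work handle returned by torch.distributed.broadcast(async_op=True); names and timings here are illustrative:

import time

class FakeWorkHandle:
    """Stand-in for the async work object returned by torch.distributed.broadcast."""
    def __init__(self, completes_at):
        self._completes_at = completes_at
    def is_completed(self):
        return time.time() >= self._completes_at

def train_loop(recv_handle, train_iters=2000, reload_every=500):
    iteration = last_reload_iteration = 0
    while iteration < train_iters:
        # A reload is overdue but the fresh index is not ready yet: wait instead of training on.
        if iteration >= last_reload_iteration + reload_every and not recv_handle.is_completed():
            time.sleep(0.001)              # the real loop sleeps 5 seconds between polls
            continue
        if recv_handle.is_completed():
            # checkpoint + reload the retriever index here, then re-arm the async receive
            last_reload_iteration = iteration
            recv_handle = FakeWorkHandle(time.time() + 0.01)
        iteration += 1                     # one training step
    return iteration

train_loop(FakeWorkHandle(time.time() + 0.01))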
pretrain_realm.py

@@ -49,7 +49,7 @@ def model_provider():
     hashed_index.add_block_embed_data(all_block_data)

     # top_k + 1 because we may need to exclude trivial candidate
-    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k + 1)
+    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)
     model = REALMBertModel(retriever)

     return model
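The comment kept above this change explains the old "+ 1": over-fetching by one block lets the retriever drop the "trivial" candidate, i.e. the block the query itself came from, and still hand back k results. A small standalone illustration of that idea (not the REALMRetriever implementation):

import numpy as np

def top_k_excluding_query_block(scores, query_block_index, k):
    """Pick the k highest-scoring blocks, skipping the block the query came from."""
    candidates = np.argsort(scores)[::-1][:k + 1]      # over-fetch by one
    kept = [int(i) for i in candidates if i != query_block_index]
    return kept[:k]

scores = np.array([0.1, 0.9, 0.4, 0.8, 0.3])
print(top_k_excluding_query_block(scores, query_block_index=1, k=2))   # -> [3, 2]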