OpenDAS / Megatron-LM
Commit 2fd4ea6c, authored May 25, 2020 by Neel Kant
Corrected realm example building, misc improvements for async concurrency
Parent: 8e22824e
Showing 6 changed files with 73 additions and 21 deletions (+73 -21)
indexer.py                          +2   -1
megatron/data/realm_dataset.py      +3   -2
megatron/model/realm_model.py       +59  -14
megatron/mpu/initialize.py          +0   -1
megatron/training.py                +8   -2
pretrain_realm.py                   +1   -1
indexer.py

@@ -95,6 +95,7 @@ def setup_realm_groups_and_vars():
 class IndexBuilder(object):
     def __init__(self):
         args = get_args()
+        self.debug = args.debug
         self.rank = args.rank
         self.model = None
         self.dataloader = None
@@ -287,6 +288,6 @@ def get_one_epoch_dataloader(dataset, batch_size=None):
 if __name__ == "__main__":
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
-    index_builder = BasicIndexBuilder()
+    index_builder = IndexBuilder()
     index_builder.build_and_save_index()
megatron/data/realm_dataset.py

@@ -5,7 +5,7 @@ import numpy as np
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
-from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping
+from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping, join_str_list


 class REALMDataset(Dataset):
@@ -136,7 +136,8 @@ class ICTDataset(Dataset):
     def decode_tokens(self, token_ids):
         tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
-        return ' '.join(token for token in tokens if token != '[PAD]')
+        non_pads = [t for t in tokens if t != '[PAD]']
+        return join_str_list(non_pads)

     def get_block(self, start_idx, end_idx, doc_idx):
         """Get the IDs for an evidence block plus the title of the corresponding document"""
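The new helper join_str_list is imported from megatron.data.realm_dataset_utils, but its body is not part of this diff. Presumably it rejoins BERT WordPiece tokens while merging '##' continuation pieces, which the old ' '.join could not do. A minimal sketch under that assumption:

def join_str_list(str_list):
    """Hypothetical sketch of the helper (the real one lives in
    megatron/data/realm_dataset_utils.py and is not shown here).
    Pieces beginning with '##' continue the previous word, so they
    are appended without an intervening space."""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result.lstrip()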
megatron/model/realm_model.py

@@ -94,51 +94,96 @@ class REALMBertModel(MegatronModule):
         self._retriever_key = 'retriever'

     def forward(self, tokens, attention_mask, query_block_indices, return_topk_block_tokens=False):
+        # print("\nNEW FORWARD", '-' * 100, flush=True)
+        dset = self.retriever.ict_dataset
+        det_tokens = detach(tokens)[0].tolist()
+        det_attention = detach(attention_mask)[0].tolist()
+        # print("\nTokens: ", det_tokens, '\n', flush=True)
+        # print("\nAttention: ", det_attention, '\n', flush=True)
+        # print("pad id: ", dset.pad_id, flush=True)
+        assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
+        if 0 in det_attention:
+            idx_padid = det_tokens.index(dset.pad_id)
+            idx_attn = det_attention.index(0)
+            assert idx_padid == idx_attn, (idx_padid, idx_attn)
+        # text = dset.decode_tokens(det_tokens)
+        # print(text, flush=True)
+        # print("Token shape: ", tokens.shape, flush=True)

         # [batch_size x k x seq_length]
         topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
             tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
+        # print("Top k block shape: ", topk_block_tokens.shape, flush=True)
         batch_size = tokens.shape[0]

         # create a copy in case it needs to be returned
         ret_topk_block_tokens = np.array(topk_block_tokens)
         seq_length = topk_block_tokens.shape[2]

-        topk_block_tokens = torch.cuda.LongTensor(topk_block_tokens).reshape(-1, seq_length)
-        topk_block_attention_mask = torch.cuda.LongTensor(topk_block_attention_mask).reshape(-1, seq_length)
+        long_tensor = torch.cuda.LongTensor
+        topk_block_tokens = long_tensor(topk_block_tokens).reshape(-1, seq_length)
+        topk_block_attention_mask = long_tensor(topk_block_attention_mask).reshape(-1, seq_length)
+        # print('Block token shape: ', topk_block_tokens.shape, flush=True)

         # [batch_size x k x embed_size]
         true_model = self.retriever.ict_model.module.module
         fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask)
         fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
+        # print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)

         # [batch_size x embed_size x 1]
         query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2)
+        # print('Query logits shape: ', query_logits.shape, flush=True)

         # [batch_size x k]
         fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
+        # print('Block score shape: ', fresh_block_scores.shape, flush=True)
         block_probs = F.softmax(fresh_block_scores, dim=1)
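Both embedding passes above run through mpu.checkpoint. Assuming it follows the same recompute-in-backward semantics as torch.utils.checkpoint.checkpoint (a sketch of the idea, not the project's exact implementation), the activations inside embed_block for all batch_size * k retrieved blocks are not stored; they are recomputed during backward, trading compute for the memory the k-times-larger block batch would otherwise consume:

import torch
from torch.utils.checkpoint import checkpoint

# Stand-in encoder; the real embed_block is a BERT tower.
def embed_block(tokens, mask):
    return (tokens * mask).sum(dim=1)

tokens = torch.randn(4, 8, requires_grad=True)
mask = torch.ones(4, 8)
logits = checkpoint(embed_block, tokens, mask)  # activations recomputed in backward
logits.sum().backward()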
         # [batch_size * k x seq_length]
         tokens = torch.stack([tokens.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
+        #assert all(tokens[i] == tokens[0] for i in range(self.top_k))
+        #assert all(tokens[i] == tokens[self.top_k] for i in range(self.top_k, 2 * self.top_k))
+        #assert not any(tokens[i] == tokens[0] for i in range(self.top_k, batch_size * self.top_k))
         attention_mask = torch.stack([attention_mask.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)

         # [batch_size * k x 2 * seq_length]
-        all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
-        all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
-        all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
+        lm_input_batch_shape = (batch_size * self.top_k, 2 * seq_length)
+        all_tokens = torch.zeros(lm_input_batch_shape).long().cuda()
+        all_attention_mask = all_tokens.clone()
+        all_token_types = all_tokens.clone()
+        #all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
+        #all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
+        #all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()

+        # re-align tokens to be contiguous
         query_lengths = torch.sum(attention_mask, axis=1)
-        block_lengths = torch.sum(topk_block_attention_mask, axis=1)
-        for row_num in range(all_tokens.shape[0]):
-            qlen = query_lengths[row_num]
-            blen = block_lengths[row_num]
-            # disregard the CLS token from the block tokens
-            new_tokens_length = qlen + blen - 1
-            all_tokens[row_num, :qlen] = tokens[row_num, :qlen]
-            all_tokens[row_num, qlen:new_tokens_length] = tokens[row_num, 1:blen]
+        # all blocks (including null ones) will have two SEP tokens
+        block_sep_indices = (topk_block_tokens == dset.sep_id).nonzero().reshape(batch_size * self.top_k, 2, 2)
+        # block body starts after the first SEP
+        block_starts = block_sep_indices[:, 0, 1] + 1
+        # block body ends after the second SEP
+        block_ends = block_sep_indices[:, 1, 1] + 1
+        # block_lengths = torch.sum(topk_block_attention_mask, axis=1)
+        for row_num in range(all_tokens.shape[0]):
+            q_len = query_lengths[row_num]
+            b_start = block_starts[row_num]
+            b_end = block_ends[row_num]
+            # new tokens = CLS + query + SEP + block + SEP
+            new_tokens_length = q_len + b_end - b_start
+            # splice query and block tokens accordingly
+            all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
+            all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
             all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
+            # print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
             all_attention_mask[row_num, :new_tokens_length] = 1
             all_attention_mask[row_num, new_tokens_length:] = 0
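The rewritten loop no longer concatenates the padded query and padded block side by side; it splices the block body (everything after the block's first SEP, through its second SEP) directly after the query's own SEP, so each LM input row is contiguous: CLS, query, SEP, block, SEP, then padding. A toy illustration of one row, with assumed BERT-style IDs (CLS=101, SEP=102, PAD=0):

import torch

query = torch.tensor([101, 7, 8, 102, 0, 0, 0, 0])    # [CLS] q1 q2 [SEP] pad...
block = torch.tensor([101, 5, 102, 9, 9, 9, 102, 0])  # [CLS] title [SEP] b1 b2 b3 [SEP] pad

q_len = int((query != 0).sum())                        # 4 real query tokens
sep_cols = (block == 102).nonzero().squeeze(1)         # columns of the two SEPs: [2, 6]
b_start, b_end = int(sep_cols[0]) + 1, int(sep_cols[1]) + 1

row = torch.zeros(16, dtype=torch.long)                # starts out all pad (assumed id 0)
row[:q_len] = query[:q_len]
row[q_len:q_len + (b_end - b_start)] = block[b_start:b_end]
print(row.tolist())  # [101, 7, 8, 102, 9, 9, 9, 102, 0, ...] = CLS query SEP body SEP pad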
megatron/mpu/initialize.py

@@ -120,7 +120,6 @@ def set_data_parallel_group(group):
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, \
         'data parallel group has already been initialized'
-    print(">>> setting data parallel group: ", group, flush=True)
     _DATA_PARALLEL_GROUP = group
megatron/training.py

@@ -18,6 +18,7 @@
 from datetime import datetime
 import math
 import sys
+import time

 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@@ -381,8 +382,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank,
                                                   group=get_gloo_comm_group(), async_op=True)
         last_reload_iteration = iteration
     while iteration < args.train_iters:
+        if iteration >= last_reload_iteration + 500 and not recv_handle.is_completed():
+            time.sleep(5)
+            continue
+
         # this only applies for realm right here
-        if args.max_training_rank is not None and recv_handle.is_completed():
+        if args.max_training_rank is not None and recv_handle.is_completed() and iteration >= last_reload_iteration + 500:
             # should add check that INDEX_READY == 1 but what else could be happening
             true_model = model
@@ -393,7 +398,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             print("> Saving model and reloading index", flush=True)
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            if args.rank == 0:
+                save_checkpoint(iteration, model, optimizer, lr_scheduler)
             true_model.retriever.reload_index()
             if args.rank == 0:
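The throttle added to train() completes a simple asynchronous handshake between the training ranks and the index-building rank: an async broadcast on a Gloo group is posted once, the loop polls is_completed() each iteration, and once 500 iterations have passed since the last reload it refuses to proceed on a stale index and sleeps instead. A condensed sketch of the pattern, with assumed setup names (max_training_rank, gloo_group, reload_index); not the file's full logic:

import time
import torch

def train_loop(iteration, train_iters, max_training_rank, gloo_group, reload_index):
    # Signal tensor: the index-building rank broadcasts once a fresh index is ready.
    INDEX_READY = torch.zeros(1)
    recv_handle = torch.distributed.broadcast(
        INDEX_READY, max_training_rank, group=gloo_group, async_op=True)
    last_reload_iteration = iteration
    while iteration < train_iters:
        # Index overdue but not ready: wait rather than train on stale retrievals.
        if iteration >= last_reload_iteration + 500 and not recv_handle.is_completed():
            time.sleep(5)
            continue
        if recv_handle.is_completed() and iteration >= last_reload_iteration + 500:
            reload_index()
            last_reload_iteration = iteration
            # Re-arm the handshake for the next index rebuild.
            recv_handle = torch.distributed.broadcast(
                INDEX_READY, max_training_rank, group=gloo_group, async_op=True)
        iteration += 1  # one training step would run here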
pretrain_realm.py

@@ -49,7 +49,7 @@ def model_provider():
     hashed_index.add_block_embed_data(all_block_data)

     # top_k + 1 because we may need to exclude trivial candidate
-    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k + 1)
+    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)
     model = REALMBertModel(retriever)
     return model
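The "+ 1" head-room is removed from the constructor call even though the comment above it survives; presumably the over-fetch now happens inside the retriever itself (the forward pass in realm_model.py asks it for include_null_doc=True). A hypothetical sketch of why an extra candidate is fetched at all: the query's own source block would be a trivial, leak-prone match, so one spare result lets it be dropped:

def retrieve_top_k(index, query_embed, query_block_idx, top_k):
    """Hypothetical sketch (not the project's REALMRetriever): fetch one extra
    candidate so the query's own source block can be excluded."""
    block_ids = index.search(query_embed, top_k + 1)       # assumed index API
    keep = [b for b in block_ids if b != query_block_idx]  # drop the trivial match
    return keep[:top_k]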