OpenDAS / Megatron-LM · Commits

Commit 20895f2c, authored Jun 04, 2020 by Neel Kant
Parent: 51204a4d

runs with new log loss but plateaus early
Showing 5 changed files with 87 additions and 31 deletions (+87, -31)
  indexer.py                       +1   -1
  megatron/data/realm_index.py     +6   -0
  megatron/model/realm_model.py    +29  -18
  megatron/training.py             +3   -1
  pretrain_realm.py                +48  -11
indexer.py

@@ -176,7 +176,7 @@ class AsyncIndexBuilder(IndexBuilder):
             print(">>>>> No realm chkpt available", flush=True)
             self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
         self.model.eval()
-        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
+        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), batch_size=128))
         self.block_data = BlockData()

     def send_index_ready_signal(self):
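Note on the change above: the index builder only re-embeds every evidence block with the frozen ICT block encoder under no_grad, so batch_size=128 just controls how many blocks go through each forward pass. A minimal sketch of that loop, using made-up stand-ins (toy_encoder, toy_loader) rather than the real load_ict_checkpoint / get_ict_dataset / BlockData APIs:

import torch

# hypothetical stand-ins for the ICT block encoder and the one-epoch dataloader
toy_encoder = torch.nn.Linear(32, 128)                  # token features -> block embedding
toy_loader = [torch.randn(128, 32) for _ in range(4)]   # 4 batches of 128 "blocks"

toy_encoder.eval()
embeddings = []
with torch.no_grad():                                   # index building needs no gradients
    for batch in toy_loader:
        embeddings.append(toy_encoder(batch))
block_embeds = torch.cat(embeddings, dim=0)             # [num_blocks x embed_size]
print(block_embeds.shape)                               # torch.Size([512, 128])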
megatron/data/realm_index.py

@@ -150,6 +150,12 @@ class FaissMIPSIndex(object):
                 for j in range(block_indices.shape[1]):
                     fresh_indices[i, j] = self.id_map[block_indices[i, j]]
             block_indices = fresh_indices
+        args = get_args()
+        if args.rank == 0:
+            torch.save({'query_embeds': query_embeds,
+                        'id_map': self.id_map,
+                        'block_indices': block_indices,
+                        'distances': distances}, 'search.data')
         return distances, block_indices

     # functions below are for ALSH, which currently isn't being used
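For reference, FaissMIPSIndex with index_type='flat_ip' wraps an exact inner-product FAISS index, and the id_map loop above translates FAISS's internal row ids back into block ids before the newly added rank-0 torch.save of the search results. A rough sketch of the same pattern against the faiss library directly (the id_map layout here is illustrative, not the class's actual fields):

import numpy as np
import faiss  # assumes faiss-cpu or faiss-gpu is installed

embed_size = 128
block_embeds = np.random.rand(1000, embed_size).astype('float32')  # stand-in block embeddings
id_map = {i: 10000 + i for i in range(len(block_embeds))}          # FAISS row id -> block id

index = faiss.IndexFlatIP(embed_size)   # exact maximum-inner-product search ('flat_ip')
index.add(block_embeds)

query_embeds = np.random.rand(4, embed_size).astype('float32')
distances, block_indices = index.search(query_embeds, 5)           # top-5 blocks per query

# remap internal row ids to block ids, as the nested loop above does
fresh_indices = np.vectorize(id_map.get)(block_indices)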
megatron/model/realm_model.py

@@ -92,9 +92,9 @@ class REALMBertModel(MegatronModule):
         self.retriever = retriever
         self.top_k = self.retriever.top_k
         self._retriever_key = 'retriever'
         # self.eval()

     def forward(self, tokens, attention_mask, query_block_indices, return_topk_block_tokens=False):
         # print("\nNEW FORWARD", '-' * 100, flush=True)
         dset = self.retriever.ict_dataset
         det_tokens = detach(tokens)[0].tolist()
@@ -112,7 +112,6 @@ class REALMBertModel(MegatronModule):
         # text = dset.decode_tokens(det_tokens)
         # print(text, flush=True)
         # print("Token shape: ", tokens.shape, flush=True)

         # [batch_size x k x seq_length]
         topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
@@ -132,23 +131,21 @@ class REALMBertModel(MegatronModule):
         # [batch_size x k x embed_size]
         true_model = self.retriever.ict_model.module.module
         fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask)
-        fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
+        fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1).float()
         # print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)

         # [batch_size x 1 x embed_size]
-        query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(1)
-        # print('Query logits shape: ', query_logits.shape, flush=True)
+        query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(1).float()

         # [batch_size x k]
         fresh_block_scores = torch.matmul(query_logits, torch.transpose(fresh_block_logits, 1, 2)).squeeze()
         # print('Block score shape: ', fresh_block_scores.shape, flush=True)
         block_probs = F.softmax(fresh_block_scores, dim=1)

         # [batch_size * k x seq_length]
         tokens = torch.stack([tokens.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
-        #assert all(tokens[i] == tokens[0] for i in range(self.top_k))
-        #assert all(tokens[i] == tokens[self.top_k] for i in range(self.top_k, 2 * self.top_k))
-        #assert not any(tokens[i] == tokens[0] for i in range(self.top_k, batch_size * self.top_k))
+        # assert all(torch.equal(tokens[i], tokens[0]) for i in range(self.top_k))
+        # assert all(torch.equal(tokens[i], tokens[self.top_k]) for i in range(self.top_k, 2 * self.top_k))
+        # assert not any(torch.equal(tokens[i], tokens[0]) for i in range(self.top_k, batch_size * self.top_k))
         attention_mask = torch.stack([attention_mask.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)

         # [batch_size * k x 2 * seq_length]
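The shape bookkeeping in this hunk is the point of the .float() changes: query embeddings [batch x 1 x embed] are scored against block embeddings [batch x k x embed] by a batched inner product, softmaxed over the k candidates into P(z|x), and the query tokens are then replicated k times so each query can be paired with each retrieved block. A standalone sketch with dummy tensors (batch_size, top_k, embed_size, seq_length are illustrative values):

import torch
import torch.nn.functional as F

batch_size, top_k, embed_size, seq_length = 2, 4, 128, 16

query_logits = torch.randn(batch_size, 1, embed_size)            # [batch x 1 x embed]
fresh_block_logits = torch.randn(batch_size, top_k, embed_size)  # [batch x k x embed]

# [batch x 1 x k] -> [batch x k]
fresh_block_scores = torch.matmul(query_logits,
                                  torch.transpose(fresh_block_logits, 1, 2)).squeeze(1)
block_probs = F.softmax(fresh_block_scores, dim=1)               # P(z | x) over the k blocks

# replicate each query k times so it can be spliced with each of its k blocks
tokens = torch.randint(0, 100, (batch_size, seq_length))
tokens = torch.stack([tokens.unsqueeze(1)] * top_k, dim=1).reshape(-1, seq_length)
print(block_probs.shape, tokens.shape)   # torch.Size([2, 4]) torch.Size([8, 16])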
@@ -156,9 +153,6 @@ class REALMBertModel(MegatronModule):
         all_tokens = torch.zeros(lm_input_batch_shape).long().cuda()
         all_attention_mask = all_tokens.clone()
         all_token_types = all_tokens.clone()
-        #all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
-        #all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
-        #all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
         query_lengths = torch.sum(attention_mask, axis=1)

         # all blocks (including null ones) will have two SEP tokens
@@ -169,23 +163,40 @@ class REALMBertModel(MegatronModule):
         # block body ends after the second SEP
         block_ends = block_sep_indices[:, 1, 1] + 1
         # block_lengths = torch.sum(topk_block_attention_mask, axis=1)

+        print('-' * 100)
         for row_num in range(all_tokens.shape[0]):
             q_len = query_lengths[row_num]
             b_start = block_starts[row_num]
             b_end = block_ends[row_num]

             # new tokens = CLS + query + SEP + block + SEP
             # new_tokens_length = q_len + b_end - b_start
-            new_tokens_length = q_len
+            new_tokens_length = q_len + b_end - b_start

             # splice query and block tokens accordingly
             all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
             # all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
+            all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
             all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
             # print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
+            print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)

             all_attention_mask[row_num, :new_tokens_length] = 1
             all_attention_mask[row_num, new_tokens_length:] = 0
+        print('-' * 100)

+        args = get_args()
+        if args.rank == 0:
+            torch.save({'lm_tokens': all_tokens,
+                        'lm_attn_mask': all_attention_mask,
+                        'query_tokens': tokens,
+                        'query_attn_mask': attention_mask,
+                        'query_logits': query_logits,
+                        'block_tokens': topk_block_tokens,
+                        'block_attn_mask': topk_block_attention_mask,
+                        'block_logits': fresh_block_logits,
+                        'block_probs': block_probs,
+                        }, 'final_lm_inputs.data')

         # assert all(torch.equal(all_tokens[i], all_tokens[0]) for i in range(self.top_k))
         # assert all(torch.equal(all_attention_mask[i], all_attention_mask[0]) for i in range(self.top_k))

         # [batch_size x k x 2 * seq_length x vocab_size]
         lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
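The splice that this hunk turns on builds each LM input as CLS + query + SEP + block + SEP: the first q_len positions come from the query, the next b_end - b_start from the retrieved block, and the tail is pad_id with the attention mask zeroed. A toy, single-row illustration (pad_id and the token values are made up):

import torch

seq_length, pad_id = 12, 0
query = torch.tensor([101, 7, 8, 9, 102])   # CLS, query tokens, SEP
block = torch.tensor([21, 22, 23, 102])     # block body, trailing SEP
q_len, b_len = len(query), len(block)
new_tokens_length = q_len + b_len           # q_len + (b_end - b_start) in the diff

all_tokens = torch.full((seq_length,), pad_id, dtype=torch.long)
all_attention_mask = torch.zeros(seq_length, dtype=torch.long)

all_tokens[:q_len] = query
all_tokens[q_len:new_tokens_length] = block
all_attention_mask[:new_tokens_length] = 1

print(all_tokens.tolist())          # [101, 7, 8, 9, 102, 21, 22, 23, 102, 0, 0, 0]
print(all_attention_mask.tolist())  # [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]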
@@ -261,7 +272,7 @@ class REALMRetriever(MegatronModule):
         true_model = self.ict_model
         # print("true model: ", true_model, flush=True)
-        query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)
+        query_embeds = true_model.embed_query(query_tokens, query_pad_mask)
         _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)

         all_topk_tokens, all_topk_pad_masks = [], []
megatron/training.py

@@ -242,6 +242,8 @@ def setup_model_and_optimizer(model_provider_func):

 def backward_step(optimizer, model, loss):
     """Backward step."""
+    # if args.rank == 0:
+    #     torch.save(lick)
     args = get_args()
     timers = get_timers()
     torch.cuda.synchronize()
@@ -392,7 +394,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
         last_reload_iteration = iteration
     while iteration < args.train_iters:
-        if args.max_training_rank is not None and iteration >= last_reload_iteration + 500:
+        if args.max_training_rank is not None and iteration >= last_reload_iteration + 100:
            if recv_handle.is_completed():
                # should add check that INDEX_READY == 1 but what else could be happening
                true_model = model
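The surrounding logic is an asynchronous hand-off: the index-building rank broadcasts an INDEX_READY flag, the training loop polls the returned handle with is_completed(), and this commit only shortens the minimum gap between checkpoint reloads from 500 to 100 iterations. A single-process sketch of that polling pattern (the gloo bootstrap, src rank, and tensor name are illustrative, not Megatron's get_gloo_comm_group setup):

import os
import torch
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

INDEX_READY = torch.zeros(1)
recv_handle = dist.broadcast(INDEX_READY, src=0, async_op=True)   # non-blocking

iteration, last_reload_iteration, reload_interval = 0, 0, 100
while iteration < 5:
    if iteration >= last_reload_iteration + reload_interval:
        if recv_handle.is_completed():
            # reload the retriever/index checkpoint here, then re-arm the broadcast
            last_reload_iteration = iteration
            recv_handle = dist.broadcast(INDEX_READY, src=0, async_op=True)
    iteration += 1

dist.destroy_process_group()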
pretrain_realm.py

@@ -14,6 +14,9 @@
 # limitations under the License.

 """Pretrain BERT for Inverse Cloze Task"""
+import sys
+
+import numpy as np
 import torch
 import torch.nn.functional as F
@@ -29,6 +32,7 @@ from megatron.training import pretrain
 from megatron.utils import reduce_losses, report_memory
 from megatron import mpu
 from indexer import initialize_and_run_async_megatron
+from megatron.mpu.initialize import get_data_parallel_group

 num_batches = 0
@@ -44,7 +48,6 @@ def model_provider():
     ict_model = load_ict_checkpoint(from_realm_chkpt=False)
     ict_dataset = get_ict_dataset(use_titles=False)
     all_block_data = BlockData.load_from_file(args.block_data_path)
-    # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
     hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=args.faiss_use_gpu)
     hashed_index.add_block_embed_data(all_block_data)
@@ -66,8 +69,6 @@ def get_batch(data_iterator):
     else:
         data = next(data_iterator)

     data_b = mpu.broadcast_data(keys, data, datatype)

     # Unpack.
@@ -96,21 +97,57 @@ def forward_step(data_iterator, model):
     # Forward model.
     lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
     # print('logits shape: ', lm_logits.shape, flush=True)
     # print('labels shape: ', labels.shape, flush=True)
     with torch.no_grad():
         max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility = mpu.checkpoint(
             get_retrieval_utility, lm_logits, block_probs, labels, loss_mask)

     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
     null_block_probs = torch.mean(block_probs[:, block_probs.shape[1] - 1])
-    block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
-    lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]
-    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(), labels.contiguous())
-    lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

+    # logits: [batch x top_k x 2 * seq_length x vocab_size]
+    # labels: [batch x seq_length]
+    relevant_logits = lm_logits[:, :, :labels.shape[1]].float()
+    # if get_args().rank == 0:
+    #     torch.save({'logits': relevant_logits.cpu(),
+    #                 'block_probs': block_probs.cpu(),
+    #                 'labels': labels.cpu(),
+    #                 'loss_mask': loss_mask.cpu(),
+    #                 'tokens': tokens.cpu(),
+    #                 'pad_mask': pad_mask.cpu(),
+    #                 }, 'tensors.data')
+    #     torch.load('gagaga')
+    block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(relevant_logits)
+    # print(torch.sum(block_probs, dim=1), flush=True)

+    def get_log_probs(logits, b_probs):
+        max_logits = torch.max(logits, dim=-1, keepdim=True)[0].expand_as(logits)
+        logits = logits - max_logits
+        softmaxed_logits = F.softmax(logits, dim=-1)
+        marginalized_probs = torch.sum(softmaxed_logits * b_probs, dim=1)
+        l_probs = torch.log(marginalized_probs)
+        return l_probs
+    log_probs = mpu.checkpoint(get_log_probs, relevant_logits, block_probs)

+    def get_loss(l_probs, labs):
+        vocab_size = l_probs.shape[2]
+        loss = torch.nn.NLLLoss(ignore_index=-1)(l_probs.reshape(-1, vocab_size), labs.reshape(-1))
+        # loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+        return loss.float()
+    lm_loss = mpu.checkpoint(get_loss, log_probs, labels)

+    # marginalized_logits = torch.sum(relevant_logits * block_probs, dim=1)
+    # vocab_size = marginalized_logits.shape[2]
+    # lm_loss_ = torch.nn.CrossEntropyLoss()(marginalized_logits.reshape(-1, vocab_size), labels.reshape(-1))
+    # lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

     reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility,
                                   avg_retrieval_utility, null_block_probs])
     # reduced_loss = reduce_losses([lm_loss])
     # torch.cuda.synchronize()

     return lm_loss, {'lm_loss': reduced_loss[0],
                      'max_ru': reduced_loss[1],
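This is the "new log loss" from the commit message, following REALM's marginalization P(y|x) = sum_z(P(y|z, x) * P(z|x)): each block's logits are softmaxed into a token distribution, the k distributions are mixed with the retrieval probabilities P(z|x), and NLLLoss is applied to the log of the mixture; the removed path instead summed raw logits across blocks and fed them to vocab_parallel_cross_entropy, which is not the same quantity. A self-contained sketch with dummy tensors (shapes and values are illustrative):

import torch
import torch.nn.functional as F

batch, top_k, seq_len, vocab = 2, 4, 8, 30

lm_logits = torch.randn(batch, top_k, seq_len, vocab)       # per-block LM logits
block_probs = F.softmax(torch.randn(batch, top_k), dim=1)   # P(z | x)
labels = torch.randint(0, vocab, (batch, seq_len))

b_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)

def get_log_probs(logits, b_probs):
    # shift by the per-position max (as the commit does), softmax per block,
    # then mix with P(z|x) before taking the log: log sum_z P(y|z,x) P(z|x)
    max_logits = torch.max(logits, dim=-1, keepdim=True)[0]
    softmaxed = F.softmax(logits - max_logits, dim=-1)
    marginalized = torch.sum(softmaxed * b_probs, dim=1)     # [batch x seq_len x vocab]
    return torch.log(marginalized)

log_probs = get_log_probs(lm_logits, b_probs)
lm_loss = torch.nn.NLLLoss(ignore_index=-1)(log_probs.reshape(-1, vocab),
                                            labels.reshape(-1))
print(lm_loss.item())

Taking the log after mixing the per-block softmaxes gives the log-likelihood of the marginal P(y|x); summing logits across blocks first, as the commented-out alternative does, is not the same quantity.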
@@ -119,10 +156,10 @@ def forward_step(data_iterator, model):
                      'null_prob': reduced_loss[4]}


-def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
+def get_retrieval_utility(lm_logits_, block_probs, labels, loss_mask):
     """log P(y | z, x) - log P(y | null, x)"""
     # [batch x seq_len x vocab_size]
-    lm_logits = lm_logits[:, :, :labels.shape[1], :]
+    lm_logits = lm_logits_[:, :, :labels.shape[1], :]
     #non_null_block_probs = block_probs[:, :-1]
     #non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
     # non_null_block_probs = non_null_block_probsexpand_as(lm_logits[:, :-1, :, :])
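Retrieval utility, log P(y|z,x) - log P(y|null,x), measures how much a retrieved block helps the model predict the masked tokens relative to the null block (by convention the last of the k candidates). The body of get_retrieval_utility is largely outside this hunk, so the following is only a toy illustration of the quantity in its docstring, not the function's implementation; shapes match the comments above and values are random:

import torch
import torch.nn.functional as F

batch, top_k, seq_len, vocab = 2, 4, 8, 30
lm_logits_ = torch.randn(batch, top_k, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
loss_mask = torch.ones(batch, seq_len)

log_probs = F.log_softmax(lm_logits_, dim=-1)                # log P(token | z, x)
# log-prob of each label under every block: [batch x k x seq_len]
label_log_probs = torch.gather(
    log_probs, 3,
    labels.unsqueeze(1).unsqueeze(3).expand(batch, top_k, seq_len, 1)).squeeze(3)
# masked sum over the sequence -> per-block log-likelihood of y: [batch x k]
per_block_ll = torch.sum(label_log_probs * loss_mask.unsqueeze(1), dim=2)

null_ll = per_block_ll[:, -1]                                # last block = null block
utility = per_block_ll[:, :-1] - null_ll.unsqueeze(1)        # log P(y|z,x) - log P(y|null,x)
max_ru, avg_ru = utility.max(dim=1)[0].mean(), utility.mean()
print(max_ru.item(), avg_ru.item())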