OpenDAS / Megatron-LM

Commit 51204a4d, authored May 30, 2020 by Neel Kant
Misc changes
Parent: dfb907fe

Showing 7 changed files, with 43 additions and 52 deletions (+43 / -52):

    indexer.py                              +0   -5
    megatron/data/realm_dataset_utils.py    +2   -1
    megatron/data/realm_index.py            +5   -6
    megatron/global_vars.py                 +2   -2
    megatron/model/distributed.py           +1   -1
    megatron/model/realm_model.py           +12  -13
    megatron/training.py                    +21  -24
indexer.py (view file @ 51204a4d)

@@ -167,10 +167,6 @@ class AsyncIndexBuilder(IndexBuilder):
            print("Starting (again!)", flush=True)
            self.build_and_save_index()
            self.send_index_ready_signal()
            while INDEX_READY == 1:
                print("Waiting for new model checkpoint.", flush=True)
                time.sleep(5)
            self.load_attributes()

    def load_attributes(self):

@@ -195,7 +191,6 @@ class AsyncIndexBuilder(IndexBuilder):
        # recv handle
        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
        torch.distributed.barrier(get_data_parallel_group())


def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
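For context, the builder side of this handshake can be reduced to a small loop: rebuild and save the index, signal readiness over the gloo group, wait for the trainer to hand the flag back, then reload the latest checkpoint. A minimal sketch under those assumptions (the group, ranks, and the `build_and_save_index` / `load_attributes` callables are hypothetical stand-ins, and the wait is expressed as a blocking broadcast rather than the file's polling loop):

```python
import torch
import torch.distributed as dist


def builder_loop(gloo_group, builder_rank, trainer_rank,
                 build_and_save_index, load_attributes):
    """Builder-side half of an INDEX_READY handshake (sketch, not the repo's exact code)."""
    index_ready = torch.zeros(1, dtype=torch.long)
    while True:
        build_and_save_index()

        # Signal readiness: set the flag and broadcast it from the builder rank.
        index_ready.fill_(1)
        dist.broadcast(index_ready, src=builder_rank, group=gloo_group)

        # Block until the trainer broadcasts the flag back after saving a new
        # checkpoint and reloading the index (it resets the flag to 0).
        dist.broadcast(index_ready, src=trainer_rank, group=gloo_group)

        # Pick up the freshly saved trainer checkpoint before rebuilding.
        load_attributes()
```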
megatron/data/realm_dataset_utils.py (view file @ 51204a4d)

@@ -96,7 +96,8 @@ def salient_span_mask(tokens, mask_id):
    # need to get all named entities
    entities = SPACY_NER(tokens_str).ents
    entities = [e for e in entities if e.text != "CLS"]
    undesired_types = ['CARDINAL', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL']
    entities = [e for e in entities if e.text != "CLS" and e.label_ not in undesired_types]
    if len(entities) == 0:
        return None
    entity_idx = np.random.randint(0, len(entities))
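As background for the filter added above, salient span masking picks a named entity at random while skipping low-salience entity types such as counts, times, and percentages. A minimal sketch with spaCy, assuming an `en_core_web_sm` pipeline as a stand-in for whatever model `SPACY_NER` wraps here:

```python
import numpy as np
import spacy

nlp = spacy.load("en_core_web_sm")  # stand-in for the module's SPACY_NER pipeline

UNDESIRED_TYPES = {'CARDINAL', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL'}


def pick_salient_span(text):
    """Return (start_char, end_char) of a randomly chosen salient entity, or None."""
    entities = [e for e in nlp(text).ents
                if e.text != "CLS" and e.label_ not in UNDESIRED_TYPES]
    if len(entities) == 0:
        return None
    entity = entities[np.random.randint(0, len(entities))]
    return entity.start_char, entity.end_char
```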
megatron/data/realm_index.py (view file @ 51204a4d)

@@ -29,7 +29,7 @@ class BlockData(object):
    def clear(self):
        """Clear the data structures to save memory"""
        self.embed_data = dict()
        self.meta_data = dict()
        # self.meta_data = dict()

    @classmethod
    def load_from_file(cls, fname):

@@ -100,7 +100,7 @@ class FaissMIPSIndex(object):
        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
        if not self.use_gpu:
            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
            print(">> Finished building index", flush=True)
            print(">> Finished building index\n", flush=True)

        if self.use_gpu:
            res = faiss.StandardGpuResources()

@@ -109,9 +109,10 @@ class FaissMIPSIndex(object):
            config.device = torch.cuda.current_device()
            config.useFloat16 = True
            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
            print(">>> Loaded Faiss index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
            print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)

    def reset_index(self):
        del self.block_mips_index
        self._set_block_index()

    def add_block_embed_data(self, all_block_data, clear_block_data=False):

@@ -120,7 +121,7 @@ class FaissMIPSIndex(object):
        if self.use_gpu:
            for i, idx in enumerate(block_indices):
                self.id_map[i] = idx

        if clear_block_data:
        if True:
            all_block_data.clear()

        if self.use_gpu:

@@ -134,8 +135,6 @@ class FaissMIPSIndex(object):
        :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
                            if False: return [num_queries x k] array of distances, and another for indices
        """
        if self.index_type == 'flat_l2':
            query_embeds = self.alsh_query_preprocess_fn(query_embeds)
        query_embeds = np.float32(detach(query_embeds))
        # query_embeds = query_embeds.float()
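For reference, the CPU path above boils down to a flat (exact) inner-product index wrapped in an ID map so that search results come back as block IDs rather than positional offsets. A minimal sketch, assuming float32 embeddings and made-up sizes:

```python
import faiss
import numpy as np

embed_size = 128
block_embeds = np.random.rand(10000, embed_size).astype(np.float32)  # placeholder data
block_ids = np.arange(10000, dtype=np.int64)                         # external block IDs

# Exact maximum-inner-product index, wrapped so we can attach our own IDs.
index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
index = faiss.IndexIDMap(index)
index.add_with_ids(block_embeds, block_ids)

query_embeds = np.random.rand(4, embed_size).astype(np.float32)
distances, ids = index.search(query_embeds, 5)  # top-5 blocks per query by inner product
```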
megatron/global_vars.py (view file @ 51204a4d)

@@ -164,14 +164,14 @@ class _Timer:
    def start(self):
        """Start the timer."""
        assert not self.started_, 'timer has already been started'
        # torch.cuda.synchronize()
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, 'timer is not started'
        # torch.cuda.synchronize()
        torch.cuda.synchronize()
        self.elapsed_ += (time.time() - self.start_time)
        self.started_ = False
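The change above touches the `torch.cuda.synchronize()` calls around timing. Because CUDA kernels launch asynchronously, a wall-clock timer that does not synchronize measures only launch time rather than kernel execution time. A minimal sketch of a synchronized timer under that reasoning (not the repo's `_Timer` itself):

```python
import time

import torch


class CudaTimer:
    """Wall-clock timer that synchronizes CUDA so elapsed time covers kernel execution."""

    def __init__(self):
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = None

    def start(self):
        assert not self.started_, 'timer has already been started'
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for queued kernels before reading the clock
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        assert self.started_, 'timer is not started'
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # make sure the timed kernels have actually finished
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False
```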
megatron/model/distributed.py (view file @ 51204a4d)

@@ -56,7 +56,7 @@ class DistributedDataParallel(MegatronModule):
            if not no_scale and not reduce_after:
                coalesced /= dist.get_world_size(group=self.data_parallel_group)
            dist.all_reduce(coalesced, group=self.data_parallel_group)
            # torch.cuda.synchronize()
            torch.cuda.synchronize()
            if not no_scale and reduce_after:
                coalesced /= dist.get_world_size(group=self.data_parallel_group)
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
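For context, the surrounding code follows the usual coalesced gradient all-reduce pattern: flatten a bucket of gradients into one tensor, all-reduce it (scaling by world size before or after), then scatter the result back into the original buffers. A minimal sketch of that pattern, assuming an already-initialized process group:

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def allreduce_gradients(grads, group=None, reduce_after=False, no_scale=False):
    """All-reduce a list of gradient tensors in one coalesced call (sketch)."""
    coalesced = _flatten_dense_tensors(grads)
    if not no_scale and not reduce_after:
        coalesced /= dist.get_world_size(group=group)
    dist.all_reduce(coalesced, group=group)
    if not no_scale and reduce_after:
        coalesced /= dist.get_world_size(group=group)
    # Copy the reduced values back into the individual gradient buffers.
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
```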
megatron/model/realm_model.py (view file @ 51204a4d)

@@ -103,11 +103,11 @@ class REALMBertModel(MegatronModule):
            # print("\nAttention: ", det_attention, '\n', flush=True)
            # print("pad id: ", dset.pad_id, flush=True)
            assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
            if 0 in det_attention:
                idx_padid = det_tokens.index(dset.pad_id)
                idx_attn = det_attention.index(0)
                assert idx_padid == idx_attn, (idx_padid, idx_attn)
            # assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
            # if 0 in det_attention:
            #     idx_padid = det_tokens.index(dset.pad_id)
            #     idx_attn = det_attention.index(0)
            #     assert idx_padid == idx_attn, (idx_padid, idx_attn)
            # text = dset.decode_tokens(det_tokens)
            # print(text, flush=True)

@@ -135,12 +135,12 @@ class REALMBertModel(MegatronModule):
        fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
        # print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)

        # [batch_size x embed_size x 1]
        query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2)
        # [batch_size x 1 x embed_size]
        query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(1)
        # print('Query logits shape: ', query_logits.shape, flush=True)

        # [batch_size x k]
        fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
        fresh_block_scores = torch.matmul(query_logits, torch.transpose(fresh_block_logits, 1, 2)).squeeze()
        # print('Block score shape: ', fresh_block_scores.shape, flush=True)

        block_probs = F.softmax(fresh_block_scores, dim=1)

@@ -175,11 +175,11 @@ class REALMBertModel(MegatronModule):
            b_start = block_starts[row_num]
            b_end = block_ends[row_num]

            # new tokens = CLS + query + SEP + block + SEP
            new_tokens_length = q_len + b_end - b_start
            # new_tokens_length = q_len + b_end - b_start
            new_tokens_length = q_len

            # splice query and block tokens accordingly
            all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
            all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
            # all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
            all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
            # print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)

@@ -226,9 +226,8 @@ class REALMRetriever(MegatronModule):
    def reload_index(self):
        args = get_args()
        self.block_data = BlockData.load_from_file(args.block_data_path)
        print("resetting index", flush=True)
        self.hashed_index.reset_index()
        self.block_data = BlockData.load_from_file(args.block_data_path)
        self.hashed_index.add_block_embed_data(self.block_data)

    def prep_query_text_for_retrieval(self, query_text):
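The shape change above is easiest to check with a tiny example: keeping the query embedding as [batch_size x 1 x embed_size], scoring against [batch_size x k x embed_size] block embeddings is a batched matmul against the transposed blocks, giving [batch_size x k] scores that can be softmaxed over the k retrieved blocks. A small standalone sketch with made-up sizes:

```python
import torch
import torch.nn.functional as F

batch_size, top_k, embed_size = 4, 8, 128
query_logits = torch.randn(batch_size, 1, embed_size)             # [B x 1 x E]
fresh_block_logits = torch.randn(batch_size, top_k, embed_size)   # [B x K x E]

# [B x 1 x E] @ [B x E x K] -> [B x 1 x K] -> squeeze -> [B x K]
fresh_block_scores = torch.matmul(
    query_logits, torch.transpose(fresh_block_logits, 1, 2)).squeeze(1)
block_probs = F.softmax(fresh_block_scores, dim=1)  # distribution over the k blocks
```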
megatron/training.py (view file @ 51204a4d)

@@ -244,7 +244,7 @@ def backward_step(optimizer, model, loss):
    """Backward step."""
    args = get_args()
    timers = get_timers()
    # torch.cuda.synchronize()
    torch.cuda.synchronize()

    # Backward pass.
    # optimizer.zero_grad(set_grads_to_None=True)

@@ -392,39 +392,36 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
        recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
        last_reload_iteration = iteration

    while iteration < args.train_iters:
        if args.max_training_rank is not None and iteration >= last_reload_iteration + 500 and not recv_handle.is_completed():
            time.sleep(5)
            continue

        # this only applies for realm right here
        if args.max_training_rank is not None and recv_handle.is_completed():
            # should add check that INDEX_READY == 1 but what else could be happening
            true_model = model
            if hasattr(true_model, 'module'):
                true_model = true_model.module
        if args.max_training_rank is not None and iteration >= last_reload_iteration + 500:
            if recv_handle.is_completed():
                # should add check that INDEX_READY == 1 but what else could be happening
                true_model = model
                if hasattr(true_model, 'module'):
                    true_model = true_model.module
            if hasattr(true_model, 'module'):
                true_model = true_model.module
            print("> Saving model and reloading index", flush=True)
                if args.rank == 0:
                    print("> Saving model and reloading index", flush=True)
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
            true_model.retriever.reload_index()
            if args.rank == 0:
                INDEX_READY = 1 - INDEX_READY
                # send handle
                torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
                true_model.retriever.reload_index()
                if args.rank == 0:
                    INDEX_READY = 1 - INDEX_READY
            torch.cuda.synchronize()
                torch.cuda.synchronize()
            recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
            last_reload_iteration = iteration
            else:
                time.sleep(5)
                continue
            # send handle
            torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
            torch.distributed.barrier(get_data_parallel_group())
            recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
            last_reload_iteration = iteration
        elif iteration < 20:
            print("moving right along", flush=True)
            # report_memory("iteration {}".format(iteration))
            report_memory("iteration {}".format(iteration))

        loss_dict, skipped_iter = train_step(forward_step_func, train_data_iterator, model,
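The trainer-side half of the handshake above relies on `async_op=True`: the broadcast returns a work handle immediately, and the training loop polls `is_completed()` each iteration instead of blocking. A minimal sketch of that pattern, with hypothetical ranks, group, and a placeholder `save_checkpoint_and_reload_index` callable standing in for the checkpointing and index reload done in the real loop:

```python
import torch
import torch.distributed as dist


def maybe_reload_index(index_ready, recv_handle, builder_rank, trainer_rank, gloo_group,
                       save_checkpoint_and_reload_index):
    """One trainer-side check of the INDEX_READY handshake (sketch, not the repo's loop)."""
    if not recv_handle.is_completed():
        return recv_handle  # builder still working; keep training

    # Builder finished a pass: persist the current model, reload the new index,
    # flip the flag, and send it back so the builder can start its next pass.
    save_checkpoint_and_reload_index()
    index_ready.fill_(0)
    dist.broadcast(index_ready, src=trainer_rank, group=gloo_group)

    # Re-post the non-blocking receive for the next readiness signal.
    return dist.broadcast(index_ready, src=builder_rank,
                          group=gloo_group, async_op=True)
```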