OpenDAS / Megatron-LM

Commit 25293807, authored May 18, 2021 by Mostofa Patwary
Parent: 2eaf6c79

    additional cleaning

Showing 6 changed files with 7 additions and 69 deletions (+7 -69):
megatron/arguments.py         +0  -6
megatron/checkpointing.py     +0  -1
megatron/indexer.py           +5  -34
megatron/learning_rates.py    +0  -16
tasks/finetune_utils.py       +2  -1
tasks/orqa/evaluate_utils.py  +0  -11
megatron/arguments.py

```diff
@@ -479,12 +479,6 @@ def _add_learning_rate_args(parser):
     group.add_argument('--min-lr', type=float, default=0.0,
                        help='Minumum value for learning rate. The scheduler'
                        'clip values below this threshold.')
-    group.add_argument('--override-lr-new', action='store_true',
-                       help='Reset the values of the scheduler (learning rate,'
-                       'warmup iterations, minimum learning rate, maximum '
-                       'number of iterations, and decay style from input '
-                       'arguments and ignore values from checkpoints. Note'
-                       'that all the above values will be reset.')
     group.add_argument('--override-lr-scheduler', action='store_true',
                        help='Reset the values of the scheduler (learning rate,'
                        'warmup iterations, minimum learning rate, maximum '
```
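The deleted `--override-lr-new` flag duplicated what the surviving `--override-lr-scheduler` flag already expresses: prefer command-line scheduler values over those stored in a checkpoint. A minimal illustrative sketch of how such a flag typically gates checkpoint restoration; the helper name and structure here are assumptions, not the repo's exact code:

```python
# Illustrative sketch: an override flag deciding whether a scheduler value
# comes from the checkpoint or from the command line. Hypothetical helper.
def restore_scheduler_value(name, checkpoint_value, current_value,
                            override_lr_scheduler):
    if override_lr_scheduler:
        # Ignore the checkpoint and keep the value from the command line.
        return current_value
    # Otherwise trust the checkpoint, but surface silent mismatches.
    if checkpoint_value != current_value:
        print('> overriding {}: {} -> {}'.format(
            name, current_value, checkpoint_value))
    return checkpoint_value
```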
megatron/checkpointing.py

```diff
@@ -419,7 +419,6 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     assert len(model) == 1
     model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()

     if mpu.get_data_parallel_rank() == 0:
```
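The context lines show a recurring pattern in this codebase: every rank loads the state dict, all ranks synchronize, and only one rank reports. A minimal standalone sketch of the same pattern, using plain `torch.distributed` rather than Megatron's `mpu` wrapper; the function and message are illustrative:

```python
import torch.distributed as dist

def load_and_report(model, state_dict, checkpoint_name):
    # Every rank loads its own copy of the weights.
    model.load_state_dict(state_dict)
    # Wait until all ranks have finished loading before reporting.
    dist.barrier()
    # Only one rank prints, so the log line is not duplicated per process.
    if dist.get_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))
```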
megatron/indexer.py

```diff
@@ -26,13 +26,10 @@ class IndexBuilder(object):
         self.evidence_embedder_obj = None
         self.biencoder_shared_query_context_model = \
             args.biencoder_shared_query_context_model
-        #self.pre_process = True
-        #self.post_process = True

         # need to know whether we're using a REALM checkpoint (args.load)
         # or ICT checkpoint
         assert not (args.load and args.ict_load)
-        #self.using_realm_chkpt = args.ict_load is None

         self.log_interval = args.indexer_log_interval
         self.batch_size = args.indexer_batch_size
```
```diff
@@ -46,24 +43,13 @@ class IndexBuilder(object):
         """
         Load the necessary attributes: model, dataloader and empty BlockData
         """
-        #args = get_args()
         only_context_model = True
         if self.biencoder_shared_query_context_model:
             only_context_model = False
-        #args.only_context_model = only_context_model
-        #args.only_query_model = False
-        #model = get_model(biencoder_model_provider)
-        #model = get_model(lambda: biencoder_model_provider(only_context_model \
-        #model = get_model(lambda: biencoder_model_provider(only_context_model \
-        #    = only_context_model, biencoder_shared_query_context_model = \
-        #    self.biencoder_shared_query_context_model,
-        #    pre_process=True, post_process=True)
-        model = get_model(get_model_provider(only_context_model= \
-            only_context_model, biencoder_shared_query_context_model= \
-            self.biencoder_shared_query_context_model))
+        model = get_model(get_model_provider(only_context_model=only_context_model,
+            biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))

         self.model = load_biencoder_checkpoint(model,
             only_context_model=only_context_model)
```
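The cleanup settles on a single `get_model(get_model_provider(...))` call in place of the commented-out lambda variants. A sketch of the factory pattern involved: `get_model_provider` returns a zero-argument provider that `get_model` can call later, with the biencoder options captured by closure. The signatures below are assumptions reconstructed from the diff, not the repo's exact definitions:

```python
def biencoder_model_provider(**kwargs):
    # Stand-in for the repo's biencoder model builder; body elided.
    ...

def get_model_provider(only_context_model=False,
                       biencoder_shared_query_context_model=False):
    def model_provider():
        # Options are captured by closure, so get_model() can build the
        # model without knowing about biencoder-specific arguments.
        return biencoder_model_provider(
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=
                biencoder_shared_query_context_model)
    return model_provider
```

A named factory reads the same as the lambda it replaces, but avoids duplicating the argument plumbing at every call site.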
```diff
@@ -103,12 +89,7 @@ class IndexBuilder(object):
         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module

-        #counter = 0
-        #start_time = time.time()
-        #cur_time = start_time
         while True:
-            #start_time = time.time()
-            #t1 = time.time()
             try:
                 # batch also has query_tokens and query_pad_data
                 row_id, context_tokens, context_mask, context_types, \
```
```diff
@@ -117,8 +98,6 @@ class IndexBuilder(object):
             except (StopIteration, IndexError):
                 break
-            #print_rank_0("get batch time {}".format(cur_time - time.time()))
-            #t2 = time.time()

             # TODO: can we add with torch.no_grad() to reduce memory usage
             # detach, separate fields and add to BlockData
             assert context_mask.dtype == torch.bool
```
```diff
@@ -128,18 +107,10 @@ class IndexBuilder(object):
             context_logits = detach(context_logits)
             row_id = detach(row_id)
-            #print_rank_0("embed text {}".format(cur_time - time.time()))
-            #t3 = time.time()
             self.evidence_embedder_obj.add_block_data(row_id, context_logits)
             self.track_and_report_progress(batch_size=len(row_id))
-            #print_rank_0("add block time {}".format(cur_time - time.time()))
-            #t4 = time.time()
-            #counter += 1
-            #if counter % 1000 == 0:
-            #    print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
-            #    print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
-            #    cur_time = time.time()

         # This process signals to finalize its shard and then synchronize with
         # the other processes
         self.evidence_embedder_obj.save_shard()
```
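The TODO left in the loop asks whether `torch.no_grad()` would reduce memory. Since the indexer only runs inference, that is the standard fix: disabling autograd stops activations from being retained for a backward pass that never happens. A minimal sketch of the loop body under `torch.no_grad()`; `model.embed_text` and `add_block_data` come from the diff, while the function and batch layout are stand-ins:

```python
import torch

def embed_batches(model, batches, block_data):
    # Inference only: no_grad() keeps autograd from retaining activations,
    # which is exactly what the TODO in the hunk above asks about.
    with torch.no_grad():
        for row_id, tokens, mask, types in batches:
            logits = model.embed_text(tokens, mask, types)
            # Move results off the graph and the GPU before storing them.
            block_data.add_block_data(row_id.cpu().numpy(),
                                      logits.cpu().numpy())
```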
megatron/learning_rates.py

```diff
@@ -18,7 +18,6 @@
 import math

 from megatron import print_rank_0
-from megatron import get_args


 class AnnealingLR(object):
     """Anneals the learning rate."""
```
```diff
@@ -60,7 +59,6 @@ class AnnealingLR(object):
         """Learning rate decay functions from:
               https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
-        #print_rank_0("self.warmup_steps {} self.num_steps {} self.decay_steps {} self.min_lr {} self.maxlr {}".format(self.warmup_steps, self.num_steps, self.decay_steps, self.min_lr, self.max_lr))

         # Use linear warmup for the initial part.
         if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
             return self.max_lr * float(self.num_steps) / \
```
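The context lines show the linear warmup: while `num_steps <= warmup_steps`, the returned rate is `max_lr * num_steps / warmup_steps`. A worked check with illustrative values:

```python
# Worked check of the linear-warmup branch shown above (values illustrative).
max_lr, warmup_steps = 1.5e-4, 2000
for num_steps in (1, 500, 2000):
    lr = max_lr * float(num_steps) / warmup_steps
    print(num_steps, lr)   # 1 -> 7.5e-08, 500 -> 3.75e-05, 2000 -> 0.00015
```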
```diff
@@ -90,20 +88,6 @@ class AnnealingLR(object):
         else:
             raise Exception('{} decay style is not supported.'.format(
                 self.decay_style))

-        args = get_args()
-        if args.override_lr_new:
-            mod_num_steps_ = min(self.num_steps, self.decay_steps - self.warmup_steps)
-            mod_num_steps_ = mod_num_steps_ - self.warmup_steps
-            use_lr = delta_lr * float(self.decay_steps - mod_num_steps_) / float(self.decay_steps)
-            should_use_lr = self.min_lr + coeff * delta_lr
-            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} should_use_lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(
-                num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, should_use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
-        else:
-            use_lr = self.min_lr + coeff * delta_lr
-            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(
-                num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
-        return use_lr
-
         return self.min_lr + coeff * delta_lr
```
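With the `override_lr_new` experiment deleted, the method again ends at the original `return self.min_lr + coeff * delta_lr`, where `coeff` comes from the decay style and `delta_lr = max_lr - min_lr`. A condensed, self-contained sketch of the schedule that remains; the cosine `coeff` is the standard form from the paper cited in the docstring, and variable names follow the diff rather than being a verbatim copy of the class:

```python
import math

def annealed_lr(num_steps, warmup_steps, decay_steps,
                max_lr, min_lr, decay_style='cosine'):
    # Linear warmup for the initial part (matches the earlier hunk).
    if warmup_steps > 0 and num_steps <= warmup_steps:
        return max_lr * float(num_steps) / float(warmup_steps)
    # After warmup, anneal from max_lr down to min_lr over decay_steps.
    num_steps_ = min(num_steps, decay_steps) - warmup_steps
    decay_steps_ = decay_steps - warmup_steps
    decay_ratio = float(num_steps_) / float(decay_steps_)
    delta_lr = max_lr - min_lr
    if decay_style == 'linear':
        coeff = 1.0 - decay_ratio
    elif decay_style == 'cosine':
        # Cosine decay: coeff falls smoothly from 1 to 0.
        coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    else:
        raise Exception('{} decay style is not supported.'.format(decay_style))
    return min_lr + coeff * delta_lr
```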
tasks/finetune_utils.py

```diff
@@ -114,7 +114,8 @@ def _build_infinite_size_dataloader(dataloader):
     iterator = dataloader.__iter__()

-def _build_train_valid_dataloaders(train_dataset, valid_dataset, task_collate_fn=None):
+def _build_train_valid_dataloaders(train_dataset, valid_dataset,
+                                   task_collate_fn=None):
     """Traing and validation dataloaders."""
     args = get_args()
```
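The context line `iterator = dataloader.__iter__()` belongs to `_build_infinite_size_dataloader`, named in the hunk header. A sketch of the usual infinite-dataloader pattern that name implies, reconstructed rather than copied from the repo: restart the underlying iterator whenever an epoch ends.

```python
# Illustrative reconstruction of the infinite-dataloader generator.
def _build_infinite_size_dataloader(dataloader):
    iterator = dataloader.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            # Epoch boundary: start a fresh pass over the data.
            iterator = dataloader.__iter__()
```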
tasks/orqa/evaluate_utils.py

```diff
@@ -44,20 +44,9 @@ class ORQAEvaluator(object):
         if args.biencoder_shared_query_context_model:
             only_query_model = False

-        #args.only_query_model = only_query_model
-        #args.only_context_model = False
         model = get_model(get_model_provider(only_query_model=only_query_model,
             biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
-        #model = get_model(lambda: biencoder_model_provider(only_query_model=\
-        #model = get_model(lambda: biencoder_model_provider(only_query_model=\
-        #    only_query_model, biencoder_shared_query_context_model=\
-        #    args.biencoder_shared_query_context_model,
-        #    pre_process=True, post_process=True))
-        #model = get_model(biencoder_model_provider)

         self.model = load_biencoder_checkpoint(model,
             only_query_model=only_query_model)
```