Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wuxk1
Megatron-LM
Commits
220637f9
Commit
220637f9
authored
May 11, 2021
by
Mostofa Patwary
Browse files
DPR evaluation debugging
parent
a8d172b3
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
97 additions
and
17 deletions
+97
-17
megatron/arguments.py
megatron/arguments.py
+6
-0
megatron/checkpointing.py
megatron/checkpointing.py
+5
-2
megatron/indexer.py
megatron/indexer.py
+32
-7
megatron/learning_rates.py
megatron/learning_rates.py
+17
-1
megatron/model/biencoder_model.py
megatron/model/biencoder_model.py
+4
-0
megatron/model/language_model.py
megatron/model/language_model.py
+6
-1
tasks/finetune_utils.py
tasks/finetune_utils.py
+4
-0
tasks/orqa/supervised/finetune.py
tasks/orqa/supervised/finetune.py
+23
-6
No files found.
megatron/arguments.py
View file @
220637f9
...
...
@@ -478,6 +478,12 @@ def _add_learning_rate_args(parser):
group
.
add_argument
(
'--min-lr'
,
type
=
float
,
default
=
0.0
,
help
=
'Minumum value for learning rate. The scheduler'
'clip values below this threshold.'
)
group
.
add_argument
(
'--override-lr-new'
,
action
=
'store_true'
,
help
=
'Reset the values of the scheduler (learning rate,'
'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style from input '
'arguments and ignore values from checkpoints. Note'
'that all the above values will be reset.'
)
group
.
add_argument
(
'--override-lr-scheduler'
,
action
=
'store_true'
,
help
=
'Reset the values of the scheduler (learning rate,'
'warmup iterations, minimum learning rate, maximum '
...
...
megatron/checkpointing.py
View file @
220637f9
...
...
@@ -413,8 +413,11 @@ def load_biencoder_checkpoint(model, only_query_model=False,
if
only_context_model
:
ret_state_dict
.
pop
(
'query_model'
)
assert
len
(
model
)
==
1
model
[
0
].
load_state_dict
(
ret_state_dict
)
#print_rank_0(len(model))
#sys.exit()
#assert len(model) == 1
#model[0].load_state_dict(ret_state_dict)
model
.
load_state_dict
(
ret_state_dict
)
torch
.
distributed
.
barrier
()
if
mpu
.
get_data_parallel_rank
()
==
0
:
...
...
megatron/indexer.py
View file @
220637f9
...
...
@@ -2,7 +2,7 @@ import sys
import
torch
import
torch.distributed
as
dist
from
megatron
import
get_args
from
megatron
import
get_args
,
print_rank_0
from
megatron
import
mpu
from
megatron.checkpointing
import
load_biencoder_checkpoint
from
megatron.data.orqa_wiki_dataset
import
get_open_retrieval_wiki_dataset
...
...
@@ -25,6 +25,8 @@ class IndexBuilder(object):
self
.
evidence_embedder_obj
=
None
self
.
biencoder_shared_query_context_model
=
\
args
.
biencoder_shared_query_context_model
self
.
pre_process
=
True
self
.
post_process
=
True
# need to know whether we're using a REALM checkpoint (args.load)
# or ICT checkpoint
...
...
@@ -47,15 +49,22 @@ class IndexBuilder(object):
if
self
.
biencoder_shared_query_context_model
:
only_context_model
=
False
model
=
get_model
(
lambda
:
biencoder_model_provider
(
only_context_model
\
#model = get_model(lambda: biencoder_model_provider(only_context_model \
# = only_context_model, biencoder_shared_query_context_model = \
# self.biencoder_shared_query_context_model, \
# pre_process=self.pre_process, post_process=self.post_process))
model
=
biencoder_model_provider
(
only_context_model
\
=
only_context_model
,
biencoder_shared_query_context_model
=
\
self
.
biencoder_shared_query_context_model
))
self
.
biencoder_shared_query_context_model
,
\
pre_process
=
self
.
pre_process
,
post_process
=
self
.
post_process
)
self
.
model
=
load_biencoder_checkpoint
(
model
,
only_context_model
=
only_context_model
)
assert
len
(
self
.
model
)
==
1
self
.
model
[
0
].
eval
()
#assert len(self.model) == 1
#self.model[0].eval()
self
.
model
.
eval
()
self
.
dataset
=
get_open_retrieval_wiki_dataset
()
self
.
dataloader
=
iter
(
get_one_epoch_dataloader
(
self
.
dataset
,
\
...
...
@@ -83,10 +92,12 @@ class IndexBuilder(object):
distributed setting will be consolidated by the rank 0 process
and saved as a final pickled BlockData.
"""
assert
len
(
self
.
model
)
==
1
unwrapped_model
=
self
.
model
[
0
]
#assert len(self.model) == 1
#unwrapped_model = self.model[0]
unwrapped_model
=
self
.
model
while
not
hasattr
(
unwrapped_model
,
'embed_text'
):
unwrapped_model
=
unwrapped_model
.
module
print_rank_0
(
"hasattr"
)
while
True
:
try
:
...
...
@@ -97,12 +108,26 @@ class IndexBuilder(object):
except
(
StopIteration
,
IndexError
):
break
print_rank_0
(
context_tokens
)
print_rank_0
(
context_mask
)
print_rank_0
(
context_types
)
#if torch.cuda.is_available():
# print_rank_0("cuda available")
#print_rank_0(torch.cuda.current_device())
#print_rank_0(torch.cuda.get_device_name())
print_rank_0
(
next
(
unwrapped_model
.
parameters
()).
device
)
print_rank_0
(
next
(
unwrapped_model
.
context_model
.
parameters
()).
device
)
#print_rank_0("After get_open_retrieval_batch")
# TODO: can we add with torch.no_grad() to reduce memory usage
# detach, separate fields and add to BlockData
assert
context_mask
.
dtype
==
torch
.
bool
context_logits
=
unwrapped_model
.
embed_text
(
unwrapped_model
.
context_model
,
context_tokens
,
context_mask
,
context_types
)
sys
.
exit
()
context_logits
=
detach
(
context_logits
)
row_id
=
detach
(
row_id
)
...
...
megatron/learning_rates.py
View file @
220637f9
...
...
@@ -18,6 +18,7 @@
import
math
from
megatron
import
print_rank_0
from
megatron
import
get_args
class
AnnealingLR
(
object
):
"""Anneals the learning rate."""
...
...
@@ -59,6 +60,7 @@ class AnnealingLR(object):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
#print_rank_0("self.warmup_steps {} self.num_steps {} self.decay_steps {} self.min_lr {} self.maxlr {}".format(self.warmup_steps, self.num_steps, self.decay_steps, self.min_lr, self.max_lr))
# Use linear warmup for the initial part.
if
self
.
warmup_steps
>
0
and
self
.
num_steps
<=
self
.
warmup_steps
:
return
self
.
max_lr
*
float
(
self
.
num_steps
)
/
\
...
...
@@ -87,7 +89,21 @@ class AnnealingLR(object):
else
:
raise
Exception
(
'{} decay style is not supported.'
.
format
(
self
.
decay_style
))
args
=
get_args
()
if
args
.
override_lr_new
:
mod_num_steps_
=
min
(
self
.
num_steps
,
self
.
decay_steps
-
self
.
warmup_steps
)
mod_num_steps_
=
mod_num_steps_
-
self
.
warmup_steps
use_lr
=
delta_lr
*
float
(
self
.
decay_steps
-
mod_num_steps_
)
/
float
(
self
.
decay_steps
)
should_use_lr
=
self
.
min_lr
+
coeff
*
delta_lr
print_rank_0
(
"num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} should_use_lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}"
.
format
(
num_steps_
,
decay_steps_
,
decay_ratio
,
coeff
,
delta_lr
,
use_lr
,
should_use_lr
,
self
.
warmup_steps
,
self
.
num_steps
,
self
.
decay_steps
))
else
:
use_lr
=
self
.
min_lr
+
coeff
*
delta_lr
print_rank_0
(
"num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}"
.
format
(
num_steps_
,
decay_steps_
,
decay_ratio
,
coeff
,
delta_lr
,
use_lr
,
self
.
warmup_steps
,
self
.
num_steps
,
self
.
decay_steps
))
return
use_lr
return
self
.
min_lr
+
coeff
*
delta_lr
...
...
megatron/model/biencoder_model.py
View file @
220637f9
...
...
@@ -266,6 +266,10 @@ class PretrainedBertModel(MegatronModule):
#extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids
=
bert_position_ids
(
input_ids
)
print_rank_0
(
input_ids
.
device
)
print_rank_0
(
position_ids
.
device
)
print_rank_0
(
extended_attention_mask
.
device
)
print_rank_0
(
tokentype_ids
.
device
)
lm_output
=
self
.
language_model
(
input_ids
,
position_ids
,
...
...
megatron/model/language_model.py
View file @
220637f9
...
...
@@ -18,7 +18,7 @@
import
torch
import
torch.nn.functional
as
F
from
megatron
import
get_args
from
megatron
import
get_args
,
print_rank_0
from
megatron
import
mpu
from
.module
import
MegatronModule
from
megatron.model.enums
import
LayerType
,
AttnMaskType
...
...
@@ -338,6 +338,11 @@ class TransformerLanguageModel(MegatronModule):
get_key_value
=
False
,
pooling_sequence_index
=
0
,
enc_hidden_states
=
None
,
output_enc_hidden
=
False
):
print_rank_0
(
"before self.embedding"
)
print_rank_0
(
enc_input_ids
.
device
)
print_rank_0
(
enc_position_ids
.
device
)
print_rank_0
(
tokentype_ids
.
device
)
# Embeddings.
if
self
.
pre_process
:
embedding_output
=
self
.
embedding
(
enc_input_ids
,
enc_position_ids
,
...
...
tasks/finetune_utils.py
View file @
220637f9
...
...
@@ -16,6 +16,7 @@
"""Finetune utilities."""
from
functools
import
partial
import
sys
import
torch
...
...
@@ -225,6 +226,9 @@ def _train(model, optimizer, lr_scheduler, forward_step,
valid_dataloader
,
model
,
iteration
,
False
)
#if iteration == 600:
# sys.exit()
# Checkpointing at the end of each epoch.
if
args
.
save
:
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
)
...
...
tasks/orqa/supervised/finetune.py
View file @
220637f9
...
...
@@ -34,6 +34,8 @@ from pretrain_ict import get_group_world_size_rank
from
tasks.finetune_utils
import
finetune
from
tasks.orqa.supervised.eval_utils
import
accuracy_func_provider
from
tasks.orqa.supervised.eval_utils
import
process_batch
,
task_collate_fn
from
tasks.orqa.evaluate_utils
import
ORQAEvaluator
from
megatron.indexer
import
IndexBuilder
def
orqa
(
Dataset
):
# , name_from_datapath_func):
...
...
@@ -226,14 +228,29 @@ def orqa(Dataset): # , name_from_datapath_func):
def
main
():
args
=
get_args
()
if
args
.
task
==
'RET-FINETUNE-NQ'
:
from
tasks.orqa.supervised.data
import
NQSupervisedDataset
as
Dataset
#
if args.task == 'RET-FINETUNE-NQ':
#
from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
#def name_from_datapath(datapath):
# return datapath[0].split('/')[-1].split('.')[0]
else
:
raise
NotImplementedError
(
'ORQA task {} is not implemented.'
.
format
(
args
.
task
))
#else:
# raise NotImplementedError('ORQA task {} is not implemented.'.format(
# args.task))
#orqa(Dataset) #, name_from_datapath)
index_builder
=
IndexBuilder
()
index_builder
.
build_and_save_index
()
print_rank_0
(
"Build and save indices: done!"
)
# Set up the model and evaluator
#evaluator = ORQAEvaluator()
# Run evaluation
#if args.qa_data_dev is not None:
# evaluator.evaluate(args.qa_data_dev, "DEV")
#if args.qa_data_test is not None:
# evaluator.evaluate(args.qa_data_test, "TEST")
orqa
(
Dataset
)
#, name_from_datapath)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment