OpenDAS / Megatron-LM · Commits

Commit 06076c7a ("implementation dpr")
Authored Apr 23, 2021 by Mostofa Patwary
Parent: cdde4338
Showing 4 changed files with 84 additions and 16 deletions.
megatron/model/biencoder_model.py        +31  -6
megatron/tokenizer/bert_tokenization.py  +29  -0
megatron/tokenizer/tokenizer.py           +4  -0
tasks/finetune_utils.py                  +20 -10
megatron/model/biencoder_model.py

@@ -17,7 +17,9 @@ from .module import MegatronModule

 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
-                             biencoder_shared_query_context_model=False):
+                             biencoder_shared_query_context_model=False,
+                             pre_process=True,
+                             post_process=True):
     """Build the model."""
     args = get_args()

@@ -35,7 +37,9 @@ def biencoder_model_provider(only_query_model=False,
         only_query_model=only_query_model,
         only_context_model=only_context_model,
         biencoder_shared_query_context_model=\
-        biencoder_shared_query_context_model)
+        biencoder_shared_query_context_model,
+        pre_process=pre_process,
+        post_process=post_process)

     return model

@@ -48,13 +52,17 @@ class BiEncoderModel(MegatronModule):
                  parallel_output=True,
                  only_query_model=False,
                  only_context_model=False,
-                 biencoder_shared_query_context_model=False):
+                 biencoder_shared_query_context_model=False,
+                 pre_process=True,
+                 post_process=True):
         super(BiEncoderModel, self).__init__()
         args = get_args()

         bert_kwargs = dict(
             num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output)
+            parallel_output=parallel_output,
+            pre_process=pre_process,
+            post_process=post_process)

         self.biencoder_shared_query_context_model = \
             biencoder_shared_query_context_model

@@ -78,6 +86,19 @@ class BiEncoderModel(MegatronModule):
             self.context_model = PretrainedBertModel(**bert_kwargs)
             self._context_key = 'context_model'

+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        #self.language_model.set_input_tensor(input_tensor)
+        return
+        # #if self._model_key is not None:
+        # #    print("_model_key {}".format(self._model_key), flush=True)
+        # print(input_tensor)
+        # if self._query_key is not None:
+        #     print("_query_key {}".format(self._query_key), flush=True)
+        # if self._context_key is not None:
+        #     print("_context_key {}".format(self._context_key), flush=True)
+        # exit()
+
     def forward(self, query_tokens, query_attention_mask, query_types,
                 context_tokens, context_attention_mask, context_types):
         """Run a forward pass for each of the models and

@@ -217,7 +238,7 @@ class PretrainedBertModel(MegatronModule):
     learned information retrieval."""

     def __init__(self, num_tokentypes=2,
-                 parallel_output=True):
+                 parallel_output=True, pre_process=True, post_process=True):
         super(PretrainedBertModel, self).__init__()
         args = get_args()

@@ -225,6 +246,8 @@ class PretrainedBertModel(MegatronModule):
         self.pad_id = tokenizer.pad
         self.biencoder_projection_dim = args.biencoder_projection_dim
         self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(
             args.init_method_std, args.num_layers)

@@ -234,7 +257,9 @@ class PretrainedBertModel(MegatronModule):
             add_pooler=False,
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         if args.biencoder_projection_dim > 0:
             self.projection_enc = get_linear_layer(args.hidden_size,
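The new pre_process/post_process flags follow Megatron's pipeline-parallelism convention: the first pipeline stage owns the embedding (pre_process) and the last stage owns the output/projection (post_process). A minimal usage sketch, assuming the standard Megatron pattern in which the training setup supplies these flags to the model provider; the wrapper function below is illustrative and not part of this commit:

    # Hypothetical provider wrapper, not part of the commit.
    from megatron.model.biencoder_model import biencoder_model_provider

    def model_provider(pre_process=True, post_process=True):
        # Under pipeline parallelism, Megatron's model setup is assumed to pass
        # pre_process/post_process according to this rank's pipeline position;
        # the provider simply forwards them so each stage only builds the
        # embedding or head weights it actually needs.
        return biencoder_model_provider(
            only_query_model=False,
            only_context_model=False,
            biencoder_shared_query_context_model=False,
            pre_process=pre_process,
            post_process=post_process)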
megatron/tokenizer/bert_tokenization.py

@@ -181,6 +181,35 @@ class FullTokenizer(object):
     def convert_ids_to_tokens(self, ids):
         return convert_by_vocab(self.inv_vocab, ids)

+    @staticmethod
+    def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
+        """ Converts a sequence of tokens (string) in a single string. """
+
+        def clean_up_tokenization(out_string):
+            """ Clean up a list of simple English tokenization artifacts
+            like spaces before punctuations and abreviated forms.
+            """
+            out_string = (
+                out_string.replace(" .", ".")
+                .replace(" ?", "?")
+                .replace(" !", "!")
+                .replace(" ,", ",")
+                .replace(" ' ", "'")
+                .replace(" n't", "n't")
+                .replace(" 'm", "'m")
+                .replace(" 's", "'s")
+                .replace(" 've", "'ve")
+                .replace(" 're", "'re")
+            )
+            return out_string
+
+        text = ' '.join(tokens).replace(' ##', '').strip()
+        if clean_up_tokenization_spaces:
+            clean_text = clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
+
     def vocab_size(self):
         return len(self.vocab)
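As an illustration of what the new static method does (the token list below is made up, not from the commit): the join produces a space-separated string, the ' ##' replacement merges WordPiece continuations, and the clean-up pass removes spaces before punctuation.

    # Illustrative trace of the new method; the input tokens are hypothetical.
    from megatron.tokenizer.bert_tokenization import FullTokenizer

    tokens = ['who', 'founded', 'mega', '##tron', '?']
    print(FullTokenizer.convert_tokens_to_string(tokens))
    # ' '.join      -> "who founded mega ##tron ?"
    # ' ##' removed -> "who founded megatron ?"
    # clean-up step -> "who founded megatron?"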
megatron/tokenizer/tokenizer.py

@@ -155,6 +155,10 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
         text_tokens = self.tokenizer.tokenize(text)
         return self.tokenizer.convert_tokens_to_ids(text_tokens)

+    def decode(self, ids):
+        tokens = self.tokenizer.convert_ids_to_tokens(ids)
+        return self.tokenizer.convert_tokens_to_string(tokens)
+
     def decode_token_ids(self, token_ids):
         tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
         exclude_list = ['[PAD]', '[CLS]']
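The added decode() chains the two FullTokenizer helpers, giving the BERT WordPiece tokenizer an encode/decode round trip. A rough usage sketch; the vocabulary path and constructor arguments here are assumptions, not taken from the commit:

    # Hypothetical round trip; constructor arguments and vocab path are assumed.
    from megatron.tokenizer.tokenizer import _BertWordPieceTokenizer

    tokenizer = _BertWordPieceTokenizer(vocab_file='bert-vocab.txt',
                                        lower_case=True)
    ids = tokenizer.tokenize('Dense passage retrieval works.')  # token ids
    print(tokenizer.decode(ids))  # e.g. "dense passage retrieval works."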
tasks/finetune_utils.py

@@ -80,7 +80,8 @@ def _cross_entropy_forward_step(batch, model):
     return output_tensor, partial(cross_entropy_loss_func, labels)


-def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
+def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
+                      task_collate_fn=None):
     """Data loader. Note that batch-size is the local (per GPU) batch-size."""

     # Sampler.

@@ -89,6 +90,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
     sampler = torch.utils.data.distributed.DistributedSampler(
         dataset, num_replicas=world_size, rank=rank)

+    print_rank_0(len(sampler))
+
     # Data loader. Note that batch size is the per GPU batch size.
     data_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=micro_batch_size,

@@ -96,7 +99,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
                                               shuffle=False,
                                               num_workers=num_workers,
                                               drop_last=drop_last,
-                                              pin_memory=True)
+                                              pin_memory=True,
+                                              collate_fn=task_collate_fn)

     return data_loader

@@ -112,21 +116,23 @@ def _build_infinite_size_dataloader(dataloader):
     iterator = dataloader.__iter__()


-def _build_train_valid_dataloaders(train_dataset, valid_dataset):
+def _build_train_valid_dataloaders(train_dataset, valid_dataset,
+                                   task_collate_fn=None):
     """Traing and validation dataloaders."""
     args = get_args()

     print_rank_0('building train and validation dataloaders ...')
     # Training dataset.
     train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
-                                         args.num_workers, not args.keep_last)
+                                         args.num_workers, not args.keep_last,
+                                         task_collate_fn)
     # Set the training iterations.
     args.train_iters_per_epoch = len(train_dataloader)
     args.train_iters = args.epochs * args.train_iters_per_epoch
     # Validation dataset. For this dataset, we do not need to set up
     # shuffling so we can just use a simple infinite loop.
     valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
-                                          args.num_workers, not args.keep_last)
+                                          args.num_workers, not args.keep_last,
+                                          task_collate_fn)
     valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

     # Now that we've built the data loaders, set batch_size arguments

@@ -185,9 +191,10 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                 continue
             # Set to zero so the next epoch does not skip any batches.
             start_iteration = 0

             # Train for one step.
             out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
             iteration += 1

@@ -220,6 +227,10 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                            valid_dataloader, model,
                                            iteration, False)

+            #if iteration == 1000:
+            #    exit()
+            #break
+
         # Checkpointing at the end of each epoch.
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)

@@ -231,7 +242,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
 def finetune(train_valid_datasets_provider, model_provider,
              forward_step=_cross_entropy_forward_step,
-             end_of_epoch_callback_provider=None):
+             end_of_epoch_callback_provider=None,
+             task_collate_fn=None):
     """Main finetune function used across all tasks."""
     args = get_args()
     timers = get_timers()

@@ -244,7 +256,7 @@ def finetune(train_valid_datasets_provider, model_provider,
     if args.epochs > 0:
         train_dataset, valid_dataset = train_valid_datasets_provider()
         train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
-            train_dataset, valid_dataset)
+            train_dataset, valid_dataset, task_collate_fn)
     else:
         args.train_iters = 0

     timers('train/valid/test dataset/dataloder').stop()

@@ -256,8 +268,6 @@ def finetune(train_valid_datasets_provider, model_provider,
         end_of_epoch_callback = end_of_epoch_callback_provider()
     timers('callback function').stop()
-    exit()
-
     # Build model, optimizer and learning rate scheduler.
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
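The task_collate_fn argument is threaded from finetune() through build_data_loader() into torch.utils.data.DataLoader(collate_fn=...), so a downstream task (such as DPR-style retrieval finetuning) can control how raw samples are batched. A hypothetical collate function sketch, assuming samples are dicts with a variable-length 'token_ids' list and an integer 'label'; neither the field names nor the dataset format come from this commit:

    # Hypothetical collate function; sample fields are assumptions.
    import torch

    def my_task_collate_fn(samples):
        # Pad each sample's token ids to the longest sequence in the batch.
        max_len = max(len(s['token_ids']) for s in samples)
        token_ids = torch.stack([
            torch.nn.functional.pad(
                torch.tensor(s['token_ids'], dtype=torch.long),
                (0, max_len - len(s['token_ids'])))
            for s in samples])
        labels = torch.tensor([s['label'] for s in samples], dtype=torch.long)
        return {'token_ids': token_ids, 'labels': labels}

    # Passed as finetune(..., task_collate_fn=my_task_collate_fn), which hands
    # it to build_data_loader() and finally to DataLoader(collate_fn=...).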