OpenDAS / Megatron-LM

Commit 7577931b
Authored May 18, 2021 by Mostofa Patwary

Fixed issues with ICT pretraining

Parent: 8e44d619
Showing 3 changed files with 29 additions and 85 deletions (+29 −85):

examples/create_embeddings.sh   +0  −32
pretrain_ict.py                 +29 −20
tools/create_doc_index.py       +0  −33
examples/create_embeddings.sh (deleted, 100644 → 0)
#!/bin/bash

# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)

RANK=0
WORLD_SIZE=1

# Wikipedia data can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path to store embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>

python tools/create_doc_index.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --retriever-seq-length 256 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --indexer-log-interval 1000 \
    --indexer-batch-size 128 \
    --vocab-file bert-vocab.txt \
    --num-workers 2 \
    --fp16
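Note that the script assigns RANK and WORLD_SIZE but, as written, does not export them, so they are shell-local unless exported or passed inline. A minimal sketch of how such environment variables are typically consumed to initialize torch.distributed follows; this is an illustration, not Megatron-LM's actual initialization code:

import os

import torch

def init_distributed():
    # Only visible to Python if the shell exports them
    # (e.g. `export RANK=0 WORLD_SIZE=1`) or passes them inline.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))
    # Hedged sketch: a real launcher would also configure the master address.
    torch.distributed.init_process_group(
        backend='nccl', rank=rank, world_size=world_size,
        init_method='tcp://localhost:6000')
    return rank, world_size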
pretrain_ict.py
@@ -14,6 +14,8 @@
 # limitations under the License.

 """Pretrain BERT for Inverse Cloze Task"""
+from functools import partial
+
 import math

 import torch

@@ -31,14 +33,15 @@ from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group


-def pretrain_ict_model_provider():
+def pretrain_ict_model_provider(pre_process=True, post_process=True):
     args = get_args()

     model = biencoder_model_provider(
                 only_context_model=False,
                 only_query_model=False,
                 biencoder_shared_query_context_model=\
-                args.biencoder_shared_query_context_model)
+                args.biencoder_shared_query_context_model,
+                pre_process=pre_process, post_process=post_process)

     return model
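The added pre_process/post_process arguments follow Megatron's pipeline-parallelism convention: each pipeline stage builds only its slice of the model, with pre_process true on the first stage (which owns the embeddings) and post_process true on the last. A hedged sketch of how a training loop might derive these flags; the helper names here are hypothetical stand-ins for Megatron's mpu utilities, not its actual API:

def build_model(model_provider, get_pipeline_rank, get_pipeline_world_size):
    # Hedged sketch: derive pre/post_process from the stage's pipeline rank.
    rank = get_pipeline_rank()
    world_size = get_pipeline_world_size()
    pre_process = (rank == 0)                 # first stage: input embeddings
    post_process = (rank == world_size - 1)   # last stage: final outputs
    return model_provider(pre_process=pre_process, post_process=post_process)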
@@ -79,25 +82,9 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
         output = output_list[rank].contiguous()
         return output


-def forward_step(data_iterator, model, input_tensor):
-    """Forward step."""
+def loss_func(output_tensor):
     args = get_args()
-    timers = get_timers()
+    query_logits, context_logits = output_tensor

-    # Get the batch.
-    timers('batch-generator').start()
-    query_tokens, query_mask, \
-    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
-    timers('batch-generator').stop()
-
-    # Query and Context Types
-    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
-    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
-
-    # Forward model.
-    query_logits, context_logits = model(query_tokens, query_mask, query_types,
-                                         context_tokens, context_mask, context_types)
     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1

@@ -139,6 +126,28 @@ def forward_step(data_iterator, model, input_tensor):
     return loss, stats_dict


+def forward_step(data_iterator, model):
+    """Forward step."""
+    args = get_args()
+    timers = get_timers()
+
+    # Get the batch.
+    timers('batch-generator').start()
+    query_tokens, query_mask, \
+    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
+    timers('batch-generator').stop()
+
+    # Query and Context Types
+    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
+    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
+
+    # Forward model.
+    output_tensor = model(query_tokens, query_mask, query_types,
+                          context_tokens, context_mask, context_types)
+
+    return output_tensor, partial(loss_func)
+
+
 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid and test datasets."""
     args = get_args()
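The net effect of this refactor is to split the old monolithic forward_step into a pure forward pass plus a loss_func returned via functools.partial; the training loop then applies the returned function to the output tensor, which keeps the forward pass reusable under pipeline parallelism. A minimal sketch of that contract, under the assumption that the trainer shown is an illustrative stand-in for megatron.training internals, not the actual code:

def train_step(forward_step_func, data_iterator, model, optimizer):
    # Hedged sketch of the new forward_step contract: the trainer receives
    # (output_tensor, loss_func) and applies loss_func itself.
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    loss, stats = loss_func(output_tensor)   # matches `return loss, stats_dict`
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return stats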
tools/create_doc_index.py (deleted, 100644 → 0)
import os
import sys
sys.path.append(os.path.abspath(os.path.join(
    os.path.dirname(__file__), os.path.pardir)))

from megatron import print_rank_0
from megatron.indexer import IndexBuilder
from megatron.initialize import initialize_megatron


def main():
    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset

    - Include all args needed for initial model specification

    Other key args:
        --block-data-path: path to write to
        --ict-load or --realm-load: path to checkpoint with which to embed
        --data-path and --titles-data-path: paths for dataset
        --indexer-log-interval: reporting interval
        --indexer-batch-size: size specific for indexer jobs

    Check README.md for example script
    """
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})

    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
    print_rank_0("Build and save indices: done!")


if __name__ == "__main__":
    main()
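For reference, the deleted examples/create_embeddings.sh above was the documented way to drive this tool. Conceptually, the IndexBuilder pass embeds each evidence block with the pretrained ICT context encoder and saves the results to the embedding path. A hedged sketch of that flow; everything here other than the general shape is an assumption, not megatron.indexer's actual API:

import torch

def build_and_save_index(context_encoder, evidence_batches, embedding_path):
    # Hedged sketch: embed every evidence block and persist the embeddings.
    # `context_encoder` and `evidence_batches` are hypothetical objects.
    all_embeddings = []
    context_encoder.eval()
    with torch.no_grad():
        for tokens, mask, types in evidence_batches:
            all_embeddings.append(context_encoder(tokens, mask, types))
    torch.save(torch.cat(all_embeddings), embedding_path)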