OpenDAS / Megatron-LM / Commits / 41a64613

Commit 41a64613, authored Mar 04, 2021 by Rewon Child

    Merge main

Parents: 8676baca, 83d26f03

Changes: 30 files in this commit; showing 20 changed files with 1824 additions and 153 deletions (+1824 / -153).
examples/create_embeddings.sh                +32   -0
examples/pretrain_ict.sh                     +44   -0
megatron/arguments.py                        +43   -8
megatron/checkpointing.py                    +63   -44
megatron/data/biencoder_dataset_utils.py     +211  -0
megatron/data/data_samplers.py               +14   -4
megatron/data/ict_dataset.py                 +20   -4
megatron/data/orqa_wiki_dataset.py           +205  -0
megatron/data/realm_index.py                 +31   -28
megatron/indexer.py                          +79   -43
megatron/initialize.py                       +2    -1
megatron/model/__init__.py                   +0    -2
megatron/model/biencoder_model.py            +295  -0
megatron/model/module.py                     +2    -2
megatron/model/transformer.py                +21   -1
megatron/mpu/__init__.py                     +3    -0
megatron/mpu/initialize.py                   +46   -3
megatron/optimizer/__init__.py               +14   -13
megatron/p2p_communication.py                +264  -0
megatron/schedules.py                        +435  -0
examples/create_embeddings.sh (new file, mode 100644)

#!/bin/bash

# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)

RANK=0
WORLD_SIZE=1

# Wikipedia data can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path to store embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>

python tools/create_doc_index.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --retriever-seq-length 256 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --indexer-log-interval 1000 \
    --indexer-batch-size 128 \
    --vocab-file bert-vocab.txt \
    --num-workers 2 \
    --fp16
examples/pretrain_ict.sh (new file, mode 100755)

#! /bin/bash

# Runs the "217M" parameter biencoder model for ICT retriever

RANK=0
WORLD_SIZE=1

PRETRAINED_BERT_PATH=<Specify path of pretrained BERT model>
TEXT_DATA_PATH=<Specify path and file prefix of the text data>
TITLE_DATA_PATH=<Specify path and file prefix of the titles>
CHECKPOINT_PATH=<Specify path>

python pretrain_ict.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 32 \
    --seq-length 256 \
    --max-position-embeddings 512 \
    --train-iters 100000 \
    --vocab-file bert-vocab.txt \
    --tokenizer-type BertWordPieceLowerCase \
    --DDP-impl torch \
    --bert-load ${PRETRAINED_BERT_PATH} \
    --log-interval 100 \
    --eval-interval 1000 \
    --eval-iters 10 \
    --retriever-report-topk-accuracies 1 5 10 20 100 \
    --retriever-score-scaling \
    --load $CHECKPOINT_PATH \
    --save $CHECKPOINT_PATH \
    --data-path ${TEXT_DATA_PATH} \
    --titles-data-path ${TITLE_DATA_PATH} \
    --lr 0.0001 \
    --lr-decay-style linear \
    --weight-decay 1e-2 \
    --clip-grad 1.0 \
    --lr-warmup-fraction 0.01 \
    --save-interval 4000 \
    --exit-interval 8000 \
    --query-in-block-prob 0.1 \
    --fp16
megatron/arguments.py

@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_validation_args(parser)
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
-    parser = _add_realm_args(parser)
+    parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)

@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
     assert args.world_size % model_parallel_size == 0, 'world size is not' \
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
         'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
     args.data_parallel_size = args.world_size // model_parallel_size

@@ -116,6 +116,15 @@ def parse_args(extra_args_provider=None, defaults={},
             print('setting global batch size to {}'.format(
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
+            'number of layers is not divisible by number of layers per virtual ' \
+            'pipeline stage'
+        args.virtual_pipeline_model_parallel_size = \
+            (args.num_layers // args.pipeline_model_parallel_size) // \
+            args.num_layers_per_virtual_pipeline_stage
+    else:
+        args.virtual_pipeline_model_parallel_size = None

     # Parameters dtype.
     args.params_dtype = torch.float

@@ -214,7 +223,7 @@ def parse_args(extra_args_provider=None, defaults={},
     custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
         seq_len % 4 == 0 and attn_batch_size % 4 == 0
-    if args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion:
+    if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
         print('WARNING: constraints for invoking optimized'
               ' fused softmax kernel are not met. We default back to unfused'
               ' kernel invocations.')

@@ -559,6 +568,8 @@ def _add_distributed_args(parser):
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')
+    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of layers per virtual pipeline stage')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')

@@ -566,6 +577,9 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
+    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
+                       help='Use scatter/gather to optimize communication of tensors in pipeline',
+                       dest='scatter_gather_tensors_in_pipeline')
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,

@@ -617,6 +631,12 @@ def _add_data_args(parser):
                        'This should be exclusive of --seq-length')
     group.add_argument('--decoder-seq-length', type=int, default=None,
                        help="Maximum decoder sequence length to process.")
+    group.add_argument('--retriever-seq-length', type=int, default=256,
+                       help='Maximum sequence length for the biencoder model '
+                       ' for retriever')
+    group.add_argument('--sample-rate', type=float, default=1.0,
+                       help='sample rate for training data. Supposed to be 0 '
+                       ' < sample_rate < 1')
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,

@@ -657,13 +677,19 @@ def _add_autoresume_args(parser):
     return parser

-def _add_realm_args(parser):
-    group = parser.add_argument_group(title='realm')
+def _add_biencoder_args(parser):
+    group = parser.add_argument_group(title='biencoder')

     # network size
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and '
                        'REALM (paper default: 128)')
+    group.add_argument('--biencoder-projection-dim', type=int, default=0,
+                       help='Size of projection head used in biencoder (paper'
+                       ' default: 128)')
+    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
+                       help='Whether to share the parameters of the query '
+                       'and context models or not')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,

@@ -680,16 +706,25 @@ def _add_realm_args(parser):
                        'ICT dataset')
     group.add_argument('--use-one-sent-docs', action='store_true',
                        help='Whether to use one sentence documents in ICT')
+    group.add_argument('--evidence-data-path', type=str, default=None,
+                       help='Path to Wikipedia Evidence frm DPR paper')

     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
-                       help="Which top-k accuracies to report (e.g. '1 5 20')")
+    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
+                       default=[], help="Which top-k accuracies to report "
+                       "(e.g. '1 5 20')")
+    group.add_argument('--retriever-score-scaling', action='store_true',
+                       help='Whether to scale retriever scores by inverse '
+                       'square root of hidden size')

     # faiss index
     group.add_argument('--faiss-use-gpu', action='store_true',
                        help='Whether create the FaissMIPSIndex on GPU')
-    group.add_argument('--block-data-path', type=str, default=None,
-                       help='Where to save/load BlockData to/from')
+    group.add_argument('--embedding-path', type=str, default=None,
+                       help='Where to save/load Open-Retrieval Embedding'
+                       ' data to/from')

     # indexer
     group.add_argument('--indexer-batch-size', type=int, default=128,
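A tiny standalone check (plain Python, hypothetical sizes) of the virtual-pipeline arithmetic added in the @@ -116 hunk above: 24 layers split over 4 pipeline stages with 3 layers per virtual stage yields 2 model chunks per stage.

# Sketch of the new virtual_pipeline_model_parallel_size computation.
num_layers = 24
pipeline_model_parallel_size = 4
num_layers_per_virtual_pipeline_stage = 3

assert num_layers % num_layers_per_virtual_pipeline_stage == 0
virtual_pipeline_model_parallel_size = \
    (num_layers // pipeline_model_parallel_size) // \
    num_layers_per_virtual_pipeline_stage
print(virtual_pipeline_model_parallel_size)  # -> 2 model chunks per stage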
megatron/checkpointing.py

@@ -21,12 +21,12 @@ import sys
 import numpy as np

 import torch
 from torch.nn.parallel import DistributedDataParallel as torchDDP

 from megatron import (get_args,
                       mpu,
                       print_rank_0,
-                      update_num_microbatches)
+                      update_num_microbatches,
+                      utils)

 _CHECKPOINT_VERSION = None

@@ -111,8 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

@@ -124,7 +123,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 3.0
         state_dict['iteration'] = iteration
-        state_dict['model'] = model.state_dict_for_save_checkpoint()
+        if len(model) == 1:
+            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
+        else:
+            for i in range(len(model)):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

         # Optimizer stuff.
         if not args.no_save_optim:

@@ -202,6 +206,33 @@ def _transpose_first_dim(t, num_splits, num_splits_first, model):
     return t

+def fix_query_key_value_ordering(model, checkpoint_version):
+    """Fix up query/key/value matrix ordering if checkpoint
+    version is smaller than 2.0
+    """
+    if checkpoint_version < 2.0:
+        for name, param in model.named_parameters():
+            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
+                if checkpoint_version == 0:
+                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
+                elif checkpoint_version == 1.0:
+                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
+                else:
+                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+                    sys.exit()
+                param.data.copy_(fixed_param)
+            if name.endswith(('.key_value.weight', '.key_value.bias')):
+                if checkpoint_version == 0:
+                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
+                elif checkpoint_version == 1.0:
+                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
+                else:
+                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+                    sys.exit()
+                param.data.copy_(fixed_param)
+        print_rank_0(" succesfully fixed query-key-values ordering for"
+                     " checkpoint version {}".format(checkpoint_version))

 def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
     """Load a model checkpoint and return the iteration.
     strict (bool): whether to strictly enforce that the keys in

@@ -211,8 +242,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     args = get_args()
     load_dir = getattr(args, load_arg)

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)

@@ -297,30 +328,17 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         print_rank_0('could not find arguments in the checkpoint ...')

     # Model.
-    model.load_state_dict(state_dict['model'], strict=strict)
+    if len(model) == 1:
+        model[0].load_state_dict(state_dict['model'], strict=strict)
+    else:
+        for i in range(len(model)):
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

-    # Fix up query/key/value matrix ordering
-    if get_checkpoint_version() < 2.0:
-        checkpoint_version = get_checkpoint_version()
-        for name, param in model.named_parameters():
-            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
-                if checkpoint_version == 0:
-                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
-                elif checkpoint_version == 1.0:
-                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
-                else:
-                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
-                    sys.exit()
-                param.data.copy_(fixed_param)
-            if name.endswith(('.key_value.weight', '.key_value.bias')):
-                if checkpoint_version == 0:
-                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
-                elif checkpoint_version == 1.0:
-                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
-                else:
-                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
-                    sys.exit()
-                param.data.copy_(fixed_param)
+    # Fix up query/key/value matrix ordering if needed
+    checkpoint_version = get_checkpoint_version()
+    print_rank_0(f' checkpoint version {checkpoint_version}')
+    fix_query_key_value_ordering(model, checkpoint_version)

     # Optimizer.
     if not release and not args.finetune and not args.no_load_optim:

@@ -365,41 +383,42 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     return iteration

-def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
-    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
+def load_biencoder_checkpoint(model, only_query_model=False,
+                              only_context_model=False, custom_load_path=None):
+    """
+    selectively load retrieval models for indexing/retrieving
+    from saved checkpoints
+    """

     args = get_args()

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

-    load_path = args.load if from_realm_chkpt else args.ict_load
+    load_path = custom_load_path if custom_load_path is not None else args.load

     tracker_filename = get_checkpoint_tracker_filename(load_path)
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())

-    # assert iteration > 0
     checkpoint_name = get_checkpoint_name(load_path, iteration, False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

     state_dict = torch.load(checkpoint_name, map_location='cpu')
-    ict_state_dict = state_dict['model']
-    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
-        print(" loading ICT state dict from REALM", flush=True)
-        ict_state_dict = ict_state_dict['retriever']['ict_model']
+    ret_state_dict = state_dict['model']

     if only_query_model:
-        ict_state_dict.pop('context_model')
-    if only_block_model:
-        ict_state_dict.pop('question_model')
+        ret_state_dict.pop('context_model')
+    if only_context_model:
+        ret_state_dict.pop('query_model')

-    model.load_state_dict(ict_state_dict)
+    assert len(model) == 1
+    model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()

     if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))

     return model
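A hedged usage sketch of the refactor above (assumes the megatron package from this commit is importable and that a checkpoint has already been loaded, so the version is set): the query/key/value reordering that used to live inline in load_checkpoint() is now the reusable helper fix_query_key_value_ordering(), so other callers (such as biencoder_model.init_state_dict_from_bert below) can apply the same fix.

from megatron.checkpointing import get_checkpoint_version, \
    fix_query_key_value_ordering

def restore_qkv_layout(module):
    # No-op for checkpoint_version >= 2.0; older checkpoints get their
    # .query_key_value / .key_value parameters transposed in place.
    checkpoint_version = get_checkpoint_version()
    fix_query_key_value_ordering(module, checkpoint_version)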
megatron/data/biencoder_dataset_utils.py (new file, mode 100644)

import os
import time

import numpy as np
import torch

from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, \
    pad_and_convert_to_numpy
from megatron.data.data_samplers import MegatronPretrainingSampler

def make_attention_mask(source_block, target_block):
    """
    Returns a 2-dimensional (2-D) attention mask
    :param source_block: 1-D array
    :param target_block: 1-D array
    """
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    mask = mask.astype(np.int64)
    # (source_length, target_length)
    return mask

def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    global_batch_size = micro_batch_size * world_size
    num_workers = args.num_workers

    # Use megatron's sampler with consumed samples set to 0 as
    # this is only for evaluation and don't intend to resume half way.
    # Also, set the drop last to false as don't intend to remove
    # the last batch
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=0,
        micro_batch_size=args.micro_batch_size,
        data_parallel_rank=mpu.get_data_parallel_rank(),
        data_parallel_size=mpu.get_data_parallel_world_size(),
        drop_last=False)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)

def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_mask',
            'context_tokens', 'context_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_mask = data_b['query_mask'] < 0.5
    context_tokens = data_b['context_tokens'].long()
    context_mask = data_b['context_mask'] < 0.5
    block_indices = data_b['block_data'].long()

    return query_tokens, query_mask,\
        context_tokens, context_mask, block_indices

def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result

class BlockSampleData(object):
    """A struct for fully describing a fixed-size block of data as used in REALM

    :param start_idx: for first sentence of the block
    :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
    :param block_idx: a unique integer identifier given to every block.
    """
    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.doc_idx = doc_idx
        self.block_idx = block_idx

    def as_array(self):
        return np.array([self.start_idx, self.end_idx,
                         self.doc_idx, self.block_idx]).astype(np.int64)

    def as_tuple(self):
        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx

class BlockSamplesMapping(object):
    def __init__(self, mapping_array):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array

    def __len__(self):
        return self.mapping_array.shape[0]

    def __getitem__(self, idx):
        """Get the data associated with an indexed sample."""
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data

def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name,
                              use_one_sent_docs=False):
    """Get samples mapping for a dataset over fixed size blocks. This function also requires
    a dataset of the titles for the source documents since their lengths must be taken into account.

    :return: samples_mapping (BlockSamplesMapping)
    """

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(name))

        from megatron.data import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose,
            use_one_sent_docs)

        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
    start_time = time.time()

    mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    samples_mapping = BlockSamplesMapping(mapping_array)

    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        mapping_array.shape[0]))

    return samples_mapping
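A toy check (numpy only, standalone) of the make_attention_mask() helper defined in this new file; the function body is copied here for illustration and the token ids are hypothetical, with 0 standing in for padding.

import numpy as np

def make_attention_mask(source_block, target_block):
    # Positions holding padding (id 0) are masked out in both dimensions.
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    return mask.astype(np.int64)

source = np.array([101, 7592, 102, 0])      # last position is padding
target = np.array([101, 2088, 102, 0, 0])
print(make_attention_mask(source, target).shape)  # (4, 5)
print(make_attention_mask(source, target)[3])     # all zeros: padded source row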
megatron/data/data_samplers.py

@@ -57,7 +57,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
 class MegatronPretrainingSampler:

     def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
+                 data_parallel_rank, data_parallel_size, drop_last=True):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples

@@ -65,6 +65,7 @@ class MegatronPretrainingSampler:
         self.data_parallel_rank = data_parallel_rank
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
+        self.drop_last = drop_last

         # Sanity checks.
         assert self.total_samples > 0, \

@@ -81,17 +82,26 @@ class MegatronPretrainingSampler:
     def __len__(self):
         return self.total_samples

+    def get_start_end_idx(self):
+        start_idx = self.data_parallel_rank * self.micro_batch_size
+        end_idx = start_idx + self.micro_batch_size
+        return start_idx, end_idx

     def __iter__(self):
         batch = []
-        # Last batch if not complete will be dropped.
+        # Last batch will be dropped if drop_last is not set False
         for idx in range(self.consumed_samples, self.total_samples):
             batch.append(idx)
             if len(batch) == self.micro_batch_times_data_parallel_size:
-                start_idx = self.data_parallel_rank * self.micro_batch_size
-                end_idx = start_idx + self.micro_batch_size
+                start_idx, end_idx = self.get_start_end_idx()
                 yield batch[start_idx:end_idx]
                 batch = []

+        # Check the last partial batch and see drop_last is set
+        if len(batch) > 0 and not self.drop_last:
+            start_idx, end_idx = self.get_start_end_idx()
+            yield batch[start_idx:end_idx]

 class MegatronPretrainingRandomSampler:
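A minimal standalone sketch (plain Python, not the Megatron class itself) of the drop_last=False behaviour added above, which the indexing dataloader relies on so no evidence rows are silently dropped: the final partial batch is yielded instead of discarded.

def iter_micro_batches(total_samples, consumed_samples, micro_batch_size,
                       data_parallel_rank, data_parallel_size, drop_last=True):
    batch = []
    start = data_parallel_rank * micro_batch_size
    end = start + micro_batch_size
    for idx in range(consumed_samples, total_samples):
        batch.append(idx)
        if len(batch) == micro_batch_size * data_parallel_size:
            yield batch[start:end]
            batch = []
    # New behaviour: emit the trailing partial batch when drop_last is False.
    if batch and not drop_last:
        yield batch[start:end]

# 10 samples, micro batch 4, one data-parallel rank: last batch has 2 samples.
print(list(iter_micro_batches(10, 0, 4, 0, 1, drop_last=False)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]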
megatron/data/ict_dataset.py

@@ -9,6 +9,16 @@ from megatron import get_args
 from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping

+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask

 def get_ict_dataset(use_titles=True, query_in_block_prob=1):
     """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())

@@ -39,7 +49,7 @@ class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
-                 seed, use_titles=True, use_one_sent_docs=False):
+                 seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length

@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
         block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
-        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)

+        query_mask = make_attention_mask(query_tokens, query_tokens)
+        context_mask = make_attention_mask(context_tokens, context_tokens)

         block_data = sample_data.as_array()

         sample = {
             'query_tokens': query_tokens,
+            'query_mask': query_mask,
             'query_pad_mask': query_pad_mask,
-            'block_tokens': block_tokens,
-            'block_pad_mask': block_pad_mask,
+            'context_tokens': context_tokens,
+            'context_mask': context_mask,
+            'context_pad_mask': context_pad_mask,
             'block_data': block_data,
         }
megatron/data/orqa_wiki_dataset.py (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wikipedia dataset from DPR code for ORQA."""

from abc import ABC
import csv
import numpy as np
import random
import torch
from torch.utils.data import Dataset

from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask

def get_open_retrieval_wiki_dataset():
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
                                           'evidence',
                                           args.evidence_data_path,
                                           tokenizer,
                                           args.retriever_seq_length)
    return dataset

def get_open_retrieval_batch(data_iterator):
    # Items and their type.
    keys = ['row_id', 'context', 'context_mask', 'context_types',
            'context_pad_mask']
    datatype = torch.int64

    # Broadcast data.
    data = None if data_iterator is None else next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    row_id = data_b['row_id'].long()
    context = data_b['context'].long()

    # TODO: make the context mask a binary one
    context_mask = (data_b['context_mask'] < 0.5)

    context_types = data_b['context_types'].long()
    context_pad_mask = data_b['context_pad_mask'].long()

    return row_id, context, context_mask, context_types, context_pad_mask

def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    title_ids = tokenizer.tokenize(row['title'])
    context_ids = tokenizer.tokenize(row['text'])

    # Appending the title of the context at front
    extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids

    context_ids, context_types, context_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_context_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return context_ids, context_types, context_pad_mask

# noinspection DuplicatedCode
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)
    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask

def build_sample(row_id, context_ids, context_types, context_pad_mask):
    """Convert to numpy and return a sample consumed by the batch producer."""

    context_ids = np.array(context_ids, dtype=np.int64)
    context_types = np.array(context_types, dtype=np.int64)
    context_mask = make_attention_mask(context_ids, context_ids)

    sample = ({
        'row_id': row_id,
        'context': context_ids,
        'context_mask': context_mask,
        'context_types': context_types,
        'context_pad_mask': context_pad_mask
    })
    return sample

class OpenRetrievalEvidenceDataset(ABC, Dataset):
    """Open Retrieval Evidence dataset class."""

    def __init__(self, task_name, dataset_name, datapath, tokenizer,
                 max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the files.
        print_rank_0(datapath)
        self.samples, self.id2text = self.process_samples_from_single_path(
            datapath)

        args = get_args()
        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0('  >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]

        context_ids, context_types, context_pad_mask = \
            build_tokens_types_paddings_from_text(row, self.tokenizer,
                                                  self.max_seq_length)

        sample = build_sample(row['doc_id'],
                              context_ids,
                              context_types,
                              context_pad_mask)
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        total = 0

        rows = []
        id2text = {}

        with open(filename) as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            next(reader, None)  # skip the headers
            for row in reader:
                # file format: doc_id, doc_text, title
                doc_id = int(row[0])
                text = row[1]
                title = row[2]

                rows.append({'doc_id': doc_id,
                             'text': text,
                             'title': title})

                assert doc_id not in id2text
                id2text[doc_id] = (text, title)

                total += 1
                if total % 100000 == 0:
                    print_rank_0('  > processed {} rows so far ...'.format(
                        total))

        print_rank_0(' >> processed {} samples.'.format(len(rows)))
        return rows, id2text
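A standalone sketch (numpy only) mirroring the capping/padding done by build_tokens_types_paddings_from_ids above: [CLS] + tokens + [SEP], truncated to max_seq_length, then padded, with a pad mask marking the real tokens. The token ids below are hypothetical.

import numpy as np

def pad_context(text_ids, max_seq_length, cls_id, sep_id, pad_id):
    enc_ids = [cls_id] + list(text_ids)
    enc_ids = enc_ids[:max_seq_length - 1] + [sep_id]   # cap, then [SEP]
    num_tokens = len(enc_ids)
    padding = max_seq_length - num_tokens
    pad_mask = np.array([1] * num_tokens + [0] * padding, dtype=np.int64)
    return enc_ids + [pad_id] * padding, pad_mask

ids, mask = pad_context([2023, 2003, 6254], max_seq_length=8,
                        cls_id=101, sep_id=102, pad_id=0)
print(ids)   # [101, 2023, 2003, 6254, 102, 0, 0, 0]
print(mask)  # [1 1 1 1 1 0 0 0]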
megatron/data/realm_index.py

@@ -14,34 +14,36 @@ def detach(tensor):
     return tensor.detach().cpu().numpy()

-class BlockData(object):
-    """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
-    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
+class OpenRetreivalDataStore(object):
+    """
+    Serializable data structure for holding data for blocks --
+    embeddings and necessary metadata for Retriever
+    """
+    def __init__(self, embedding_path=None, load_from_path=True, rank=None):
         self.embed_data = dict()
         self.meta_data = dict()
-        if block_data_path is None:
+        if embedding_path is None:
             args = get_args()
-            block_data_path = args.block_data_path
+            embedding_path = args.embedding_path
             rank = args.rank
-        self.block_data_path = block_data_path
+        self.embedding_path = embedding_path
         self.rank = rank

         if load_from_path:
             self.load_from_file()

-        block_data_name = os.path.splitext(self.block_data_path)[0]
+        block_data_name = os.path.splitext(self.embedding_path)[0]
         self.temp_dir_name = block_data_name + '_tmp'

     def state(self):
         return {
             'embed_data': self.embed_data,
             'meta_data': self.meta_data,
         }

     def clear(self):
-        """Clear the embedding data structures to save memory.
-        The metadata ends up getting used, and is also much smaller in dimensionality
-        so it isn't really worth clearing.
-        """
+        """
+        Clear the embedding data structures to save memory.
+        The metadata ends up getting used, and is also much smaller in
+        dimensionality so it isn't really worth clearing.
+        """
         self.embed_data = dict()

@@ -50,38 +52,39 @@ class BlockData(object):
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Unpickling BlockData", flush=True)
-        state_dict = pickle.load(open(self.block_data_path, 'rb'))
+        state_dict = pickle.load(open(self.embedding_path, 'rb'))
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print(">> Finished unpickling BlockData\n", flush=True)

         self.embed_data = state_dict['embed_data']
         self.meta_data = state_dict['meta_data']

-    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
-        """Add data for set of blocks
-        :param block_indices: 1D array of unique int ids for the blocks
+    def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
+        """
+        Add data for set of blocks
+        :param row_id: 1D array of unique int ids for the blocks
         :param block_embeds: 2D array of embeddings of the blocks
-        :param block_metas: 2D array of metadata for the blocks.
-            In the case of REALM this will be [start_idx, end_idx, doc_idx]
+            In the case of retriever this will be [start_idx, end_idx, doc_idx]
         """
-        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
+        for idx, embed in zip(row_id, block_embeds):
             if not allow_overwrite and idx in self.embed_data:
                 raise ValueError("Unexpectedly tried to overwrite block data")

             self.embed_data[idx] = np.float16(embed)
-            self.meta_data[idx] = meta

     def save_shard(self):
-        """Save the block data that was created this in this process"""
+        """
+        Save the block data that was created this in this process
+        """
         if not os.path.isdir(self.temp_dir_name):
             os.makedirs(self.temp_dir_name, exist_ok=True)

         # save the data for each shard
-        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
-            pickle.dump(self.state(), data_file)
+        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
+            as writer:
+            pickle.dump(self.state(), writer)

     def merge_shards_and_save(self):
-        """Combine all the shards made using self.save_shard()"""
+        # Combine all the shards made using save_shard
         shard_names = os.listdir(self.temp_dir_name)
         seen_own_shard = False

@@ -96,15 +99,15 @@ class BlockData(object):
                 old_size = len(self.embed_data)
                 shard_size = len(data['embed_data'])

-                # add the shard's data and check to make sure there is no overlap
+                # add the shard's data and check to make sure there
+                # is no overlap
                 self.embed_data.update(data['embed_data'])
                 self.meta_data.update(data['meta_data'])
                 assert len(self.embed_data) == old_size + shard_size

         assert seen_own_shard

         # save the consolidated shards and remove temporary directory
-        with open(self.block_data_path, 'wb') as final_file:
+        with open(self.embedding_path, 'wb') as final_file:
             pickle.dump(self.state(), final_file)
         shutil.rmtree(self.temp_dir_name, ignore_errors=True)
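A hedged usage sketch of the renamed data store (assumes the Megatron-LM code from this commit is importable; the path and sizes below are hypothetical). Passing embedding_path explicitly avoids the get_args() lookup, which makes the class usable outside a full Megatron run.

import numpy as np
from megatron.data.realm_index import OpenRetreivalDataStore

store = OpenRetreivalDataStore(embedding_path='/tmp/wiki_embeddings.pkl',
                               load_from_path=False, rank=0)
row_ids = np.arange(4)                        # unique int ids for 4 contexts
embeds = np.random.randn(4, 128).astype(np.float32)
store.add_block_data(row_ids, embeds)         # stored internally as float16
store.save_shard()                            # writes /tmp/wiki_embeddings_tmp/0.pkl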
megatron/indexer.py

 import sys
 import torch
 import torch.distributed as dist

 from megatron import get_args
 from megatron import mpu
-from megatron.checkpointing import load_ict_checkpoint
-from megatron.data.ict_dataset import get_ict_dataset
-from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
-from megatron.data.realm_index import detach, BlockData
-from megatron.data.realm_dataset_utils import get_ict_batch
-from megatron.model.realm_model import general_ict_model_provider
+from megatron.checkpointing import load_biencoder_checkpoint
+from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
+from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
+from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
+from megatron.data.realm_index import detach, OpenRetreivalDataStore
+from megatron.model.biencoder_model import biencoder_model_provider
 from megatron.training import get_model


 class IndexBuilder(object):
-    """Object for taking one pass over a dataset and creating a BlockData of its embeddings"""
+    """
+    Object for taking one pass over a dataset and creating a BlockData of its
+    embeddings
+    """
     def __init__(self):
         args = get_args()
         self.model = None
         self.dataloader = None
-        self.block_data = None
+        self.evidence_embedder_obj = None
+        self.biencoder_shared_query_context_model = \
+            args.biencoder_shared_query_context_model

-        # need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
+        # need to know whether we're using a REALM checkpoint (args.load)
+        # or ICT checkpoint
         assert not (args.load and args.ict_load)
-        self.using_realm_chkpt = args.ict_load is None
+        #self.using_realm_chkpt = args.ict_load is None

         self.log_interval = args.indexer_log_interval
         self.batch_size = args.indexer_batch_size

@@ -33,59 +40,88 @@ class IndexBuilder(object):
         self.iteration = self.total_processed = 0

     def load_attributes(self):
-        """Load the necessary attributes: model, dataloader and empty BlockData"""
-        model = get_model(lambda: general_ict_model_provider(only_block_model=True))
-        self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
-        self.model.eval()
-        self.dataset = get_ict_dataset()
-        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
-        self.block_data = BlockData(load_from_path=False)
+        """
+        Load the necessary attributes: model, dataloader and empty BlockData
+        """
+        only_context_model = True
+        if self.biencoder_shared_query_context_model:
+            only_context_model = False
+
+        model = get_model(lambda: biencoder_model_provider(only_context_model \
+            = only_context_model, biencoder_shared_query_context_model = \
+            self.biencoder_shared_query_context_model))
+
+        self.model = load_biencoder_checkpoint(model,
+                only_context_model=only_context_model)
+
+        assert len(self.model) == 1
+        self.model[0].eval()
+
+        self.dataset = get_open_retrieval_wiki_dataset()
+        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
+                                                        self.batch_size))
+
+        self.evidence_embedder_obj = OpenRetreivalDataStore( \
+            load_from_path=False)

     def track_and_report_progress(self, batch_size):
-        """Utility function for tracking progress"""
+        """
+        Utility function for tracking progress
+        """
         self.iteration += 1
         self.total_processed += batch_size * self.num_total_builders
         if self.is_main_builder and self.iteration % self.log_interval == 0:
-            print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)
+            print('Batch {:10d} | Total {:10d}'.format(self.iteration,
+                self.total_processed), flush=True)

     def build_and_save_index(self):
-        """Goes through one epoch of the dataloader and adds all data to this instance's BlockData.
-        The copy of BlockData is saved as a shard, which when run in a distributed setting will be
-        consolidated by the rank 0 process and saved as a final pickled BlockData.
+        """
+        Goes through one epoch of the dataloader and adds all data to this
+        instance's BlockData.
+
+        The copy of BlockData is saved as a shard, which when run in a
+        distributed setting will be consolidated by the rank 0 process
+        and saved as a final pickled BlockData.
         """
+        assert len(self.model) == 1
+        unwrapped_model = self.model[0]
+        while not hasattr(unwrapped_model, 'embed_text'):
+            unwrapped_model = unwrapped_model.module

         while True:
             try:
                 # batch also has query_tokens and query_pad_data
-                _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
+                row_id, context_tokens, context_mask, context_types, \
+                    context_pad_mask = get_open_retrieval_batch( \
+                    self.dataloader)
             except (StopIteration, IndexError):
                 break

-            unwrapped_model = self.model
-            while not hasattr(unwrapped_model, 'embed_block'):
-                unwrapped_model = unwrapped_model.module
-
-            # detach, separate fields and add to BlockData
-            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
-            detached_data = detach(block_sample_data)
-
-            # block_sample_data is a 2D array [batch x 4]
-            # with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
-            block_indices = detached_data[:, 3]
-            block_metas = detached_data[:, :3]
-
-            self.block_data.add_block_data(block_indices, block_logits, block_metas)
-            self.track_and_report_progress(batch_size=block_tokens.shape[0])
-
-        # This process signals to finalize its shard and then synchronize with the other processes
-        self.block_data.save_shard()
+            # TODO: can we add with torch.no_grad() to reduce memory usage
+            # detach, separate fields and add to BlockData
+            assert context_mask.dtype == torch.bool
+            context_logits = unwrapped_model.embed_text(
+                unwrapped_model.context_model, context_tokens, context_mask,
+                context_types)
+            context_logits = detach(context_logits)
+            row_id = detach(row_id)
+
+            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
+            self.track_and_report_progress(batch_size=len(row_id))
+
+        # This process signals to finalize its shard and then synchronize with
+        # the other processes
+        self.evidence_embedder_obj.save_shard()
         torch.distributed.barrier()
         del self.model

         # rank 0 process builds the final copy
         if self.is_main_builder:
-            self.block_data.merge_shards_and_save()
+            self.evidence_embedder_obj.merge_shards_and_save()
             # make sure that every single piece of data was embedded
-            assert len(self.block_data.embed_data) == len(self.dataset)
-        self.block_data.clear()
+            assert len(self.evidence_embedder_obj.embed_data) == \
+                len(self.dataset)
+        self.evidence_embedder_obj.clear()

         # complete building the final copy
         torch.distributed.barrier()
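A hedged sketch of how an indexing job might drive IndexBuilder. The actual driver is tools/create_doc_index.py (invoked by examples/create_embeddings.sh above) and is not shown in this diff, so the initialization details here are assumptions.

from megatron.initialize import initialize_megatron
from megatron.indexer import IndexBuilder

def main():
    initialize_megatron()              # parses the indexer args shown above
    builder = IndexBuilder()
    builder.load_attributes()          # model, dataloader, empty data store
    builder.build_and_save_index()     # embeds every evidence row, saves shards

if __name__ == "__main__":
    main()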
megatron/initialize.py

@@ -133,7 +133,8 @@ def _initialize_distributed():
             print('model parallel is already initialized')
         else:
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size)
+                                          args.pipeline_model_parallel_size,
+                                          args.virtual_pipeline_model_parallel_size)

 def _init_autoresume():
megatron/model/__init__.py

@@ -34,13 +34,11 @@ from .bert_model import (BertModel,
                          BertModelFirstStage,
                          BertModelIntermediateStage,
                          BertModelLastStage)
-from .realm_model import ICTBertModel
 from .gpt_model import (GPTModel,
                         GPTModelFirstStage,
                         GPTModelIntermediateStage,
                         GPTModelLastStage)
 from .language_model import get_language_model
 from .module import FP16Module
-from .realm_model import ICTBertModel
megatron/model/biencoder_model.py
0 → 100644
View file @
41a64613
import
os
import
torch
import
sys
from
megatron
import
get_args
,
print_rank_0
from
megatron.checkpointing
import
fix_query_key_value_ordering
from
megatron.checkpointing
import
get_checkpoint_tracker_filename
from
megatron.checkpointing
import
get_checkpoint_name
from
megatron
import
mpu
,
get_tokenizer
from
megatron.model.bert_model
import
bert_position_ids
from
megatron.model.enums
import
AttnMaskType
from
megatron.model.language_model
import
get_language_model
from
megatron.model.utils
import
get_linear_layer
from
megatron.model.utils
import
init_method_normal
from
megatron.model.utils
import
scaled_init_method_normal
from
.module
import
MegatronModule
def
biencoder_model_provider
(
only_query_model
=
False
,
only_context_model
=
False
,
biencoder_shared_query_context_model
=
False
):
"""Build the model."""
args
=
get_args
()
assert
mpu
.
get_tensor_model_parallel_world_size
()
==
1
and
\
mpu
.
get_pipeline_model_parallel_world_size
()
==
1
,
\
"Model parallel size > 1 not supported for ICT"
print_rank_0
(
'building BiEncoderModel...'
)
# simpler to just keep using 2 tokentypes since
# the LM we initialize with has 2 tokentypes
model
=
BiEncoderModel
(
num_tokentypes
=
2
,
parallel_output
=
False
,
only_query_model
=
only_query_model
,
only_context_model
=
only_context_model
,
biencoder_shared_query_context_model
=
\
biencoder_shared_query_context_model
)
return
model
class
BiEncoderModel
(
MegatronModule
):
"""Bert-based module for Biencoder model."""
def
__init__
(
self
,
num_tokentypes
=
1
,
parallel_output
=
True
,
only_query_model
=
False
,
only_context_model
=
False
,
biencoder_shared_query_context_model
=
False
):
super
(
BiEncoderModel
,
self
).
__init__
()
args
=
get_args
()
bert_kwargs
=
dict
(
num_tokentypes
=
num_tokentypes
,
parallel_output
=
parallel_output
)
self
.
biencoder_shared_query_context_model
=
\
biencoder_shared_query_context_model
assert
not
(
only_context_model
and
only_query_model
)
self
.
use_context_model
=
not
only_query_model
self
.
use_query_model
=
not
only_context_model
self
.
biencoder_projection_dim
=
args
.
biencoder_projection_dim
if
self
.
biencoder_shared_query_context_model
:
self
.
model
=
PretrainedBertModel
(
**
bert_kwargs
)
self
.
_model_key
=
'shared_model'
self
.
query_model
,
self
.
context_model
=
self
.
model
,
self
.
model
else
:
if
self
.
use_query_model
:
# this model embeds (pseudo-)queries - Embed_input in the paper
self
.
query_model
=
PretrainedBertModel
(
**
bert_kwargs
)
self
.
_query_key
=
'query_model'
if
self
.
use_context_model
:
# this model embeds evidence blocks - Embed_doc in the paper
self
.
context_model
=
PretrainedBertModel
(
**
bert_kwargs
)
self
.
_context_key
=
'context_model'
def
forward
(
self
,
query_tokens
,
query_attention_mask
,
query_types
,
context_tokens
,
context_attention_mask
,
context_types
):
"""Run a forward pass for each of the models and
return the respective embeddings."""
if
self
.
use_query_model
:
query_logits
=
self
.
embed_text
(
self
.
query_model
,
query_tokens
,
query_attention_mask
,
query_types
)
else
:
raise
ValueError
(
"Cannot embed query without the query model."
)
if
self
.
use_context_model
:
context_logits
=
self
.
embed_text
(
self
.
context_model
,
context_tokens
,
context_attention_mask
,
context_types
)
else
:
raise
ValueError
(
"Cannot embed block without the block model."
)
return
query_logits
,
context_logits
@
staticmethod
def
embed_text
(
model
,
tokens
,
attention_mask
,
token_types
):
"""Embed a batch of tokens using the model"""
logits
=
model
(
tokens
,
attention_mask
,
token_types
)
return
logits
def
state_dict_for_save_checkpoint
(
self
,
destination
=
None
,
\
prefix
=
''
,
keep_vars
=
False
):
"""Save dict with state dicts of each of the models."""
state_dict_
=
{}
if
self
.
biencoder_shared_query_context_model
:
state_dict_
[
self
.
_model_key
]
=
\
self
.
model
.
state_dict_for_save_checkpoint
(
destination
,
prefix
,
keep_vars
)
else
:
if
self
.
use_query_model
:
state_dict_
[
self
.
_query_key
]
=
\
self
.
query_model
.
state_dict_for_save_checkpoint
(
destination
,
prefix
,
keep_vars
)
if
self
.
use_context_model
:
state_dict_
[
self
.
_context_key
]
=
\
self
.
context_model
.
state_dict_for_save_checkpoint
(
destination
,
prefix
,
keep_vars
)
return
state_dict_
def
load_state_dict
(
self
,
state_dict
,
strict
=
True
):
"""Load the state dicts of each of the models"""
if
self
.
biencoder_shared_query_context_model
:
print_rank_0
(
"Loading shared query-context model"
)
self
.
model
.
load_state_dict
(
state_dict
[
self
.
_model_key
],
\
strict
=
strict
)
else
:
if
self
.
use_query_model
:
print_rank_0
(
"Loading query model"
)
self
.
query_model
.
load_state_dict
(
\
state_dict
[
self
.
_query_key
],
strict
=
strict
)
if
self
.
use_context_model
:
print_rank_0
(
"Loading context model"
)
self
.
context_model
.
load_state_dict
(
\
state_dict
[
self
.
_context_key
],
strict
=
strict
)
def
init_state_dict_from_bert
(
self
):
"""Initialize the state from a pretrained BERT model
on iteration zero of ICT pretraining"""
args
=
get_args
()
if
args
.
bert_load
is
None
:
print_rank_0
(
"bert-load argument is None"
)
return
tracker_filename
=
get_checkpoint_tracker_filename
(
args
.
bert_load
)
if
not
os
.
path
.
isfile
(
tracker_filename
):
raise
FileNotFoundError
(
"Could not find BERT checkpoint"
)
with
open
(
tracker_filename
,
'r'
)
as
f
:
iteration
=
int
(
f
.
read
().
strip
())
assert
iteration
>
0
checkpoint_name
=
get_checkpoint_name
(
args
.
bert_load
,
iteration
,
False
)
if
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
'global rank {} is loading BERT checkpoint {}'
.
format
(
torch
.
distributed
.
get_rank
(),
checkpoint_name
))
# Load the checkpoint.
try
:
state_dict
=
torch
.
load
(
checkpoint_name
,
map_location
=
'cpu'
)
except
ModuleNotFoundError
:
from
megatron.fp16_deprecated
import
loss_scaler
# For backward compatibility.
print_rank_0
(
' > deserializing using the old code structure ...'
)
sys
.
modules
[
'fp16.loss_scaler'
]
=
sys
.
modules
[
'megatron.fp16_deprecated.loss_scaler'
]
sys
.
modules
[
'megatron.fp16.loss_scaler'
]
=
sys
.
modules
[
'megatron.fp16_deprecated.loss_scaler'
]
state_dict
=
torch
.
load
(
checkpoint_name
,
map_location
=
'cpu'
)
sys
.
modules
.
pop
(
'fp16.loss_scaler'
,
None
)
sys
.
modules
.
pop
(
'megatron.fp16.loss_scaler'
,
None
)
except
BaseException
:
print_rank_0
(
'could not load the BERT checkpoint'
)
sys
.
exit
()
checkpoint_version
=
state_dict
.
get
(
'checkpoint_version'
,
0
)
# load the LM state dict into each model
model_dict
=
state_dict
[
'model'
][
'language_model'
]
if
self
.
biencoder_shared_query_context_model
:
self
.
model
.
language_model
.
load_state_dict
(
model_dict
)
fix_query_key_value_ordering
(
self
.
model
,
checkpoint_version
)
else
:
if
self
.
use_query_model
:
self
.
query_model
.
language_model
.
load_state_dict
(
model_dict
)
# give each model the same ict_head to begin with as well
if
self
.
biencoder_projection_dim
>
0
:
query_proj_state_dict
=
\
self
.
state_dict_for_save_checkpoint
()
\
[
self
.
_query_key
][
'projection_enc'
]
fix_query_key_value_ordering
(
self
.
query_model
,
checkpoint_version
)
if
self
.
use_context_model
:
self
.
context_model
.
language_model
.
load_state_dict
(
model_dict
)
if
self
.
query_model
is
not
None
and
\
self
.
biencoder_projection_dim
>
0
:
self
.
context_model
.
projection_enc
.
load_state_dict
\
(
query_proj_state_dict
)
fix_query_key_value_ordering
(
self
.
context_model
,
checkpoint_version
)
class PretrainedBertModel(MegatronModule):
    """BERT-based encoder for queries or contexts used for
    learned information retrieval."""

    def __init__(self, num_tokentypes=2, parallel_output=True):
        super(PretrainedBertModel, self).__init__()

        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.biencoder_projection_dim = args.biencoder_projection_dim
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        if args.biencoder_projection_dim > 0:
            self.projection_enc = get_linear_layer(
                args.hidden_size, args.biencoder_projection_dim, init_method)
            self._projection_enc_key = 'projection_enc'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = attention_mask.unsqueeze(1)
        #extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)
        # This mask will be used in average-pooling and max-pooling
        pool_mask = (input_ids == self.pad_id).unsqueeze(2)

        # Taking the representation of the [CLS] token of BERT
        pooled_output = lm_output[:, 0, :]

        # Converting to float16 dtype
        pooled_output = pooled_output.to(lm_output.dtype)

        # Output.
        if self.biencoder_projection_dim:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)

        if self.biencoder_projection_dim > 0:
            state_dict_[self._projection_enc_key] = \
                self.projection_enc.state_dict(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        print_rank_0("loading BERT weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if self.biencoder_projection_dim > 0:
            print_rank_0("loading projection head weights")
            self.projection_enc.load_state_dict(
                state_dict[self._projection_enc_key], strict=strict)
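The encoder above returns one fixed-size vector per input sequence (the projected [CLS] representation). As a rough illustration of how two such encoders fit into biencoder retrieval, here is a minimal sketch that is not part of the diff: query_encoder, context_encoder, and the input tensors are hypothetical stand-ins for PretrainedBertModel instances and tokenized batches, and a plain in-batch dot product stands in for the actual training objective.

# Illustrative sketch only: score every query against every context in a batch
# by a dot product of their pooled encodings. The encoder objects and inputs
# are hypothetical; any module mapping (ids, mask) -> [batch, dim] works here.
import torch

def in_batch_retrieval_scores(query_encoder, context_encoder,
                              query_ids, query_mask,
                              context_ids, context_mask):
    query_repr = query_encoder(query_ids, query_mask)          # [batch, dim]
    context_repr = context_encoder(context_ids, context_mask)  # [batch, dim]
    # scores[i, j] = similarity of query i to context j.
    return torch.matmul(query_repr, context_repr.t())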
megatron/model/module.py
View file @
41a64613
...
...
@@ -50,9 +50,9 @@ class MegatronModule(torch.nn.Module):
     def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage():
+        if mpu.is_pipeline_first_stage(ignore_virtual=True):
             return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage():
+        if mpu.is_pipeline_last_stage(ignore_virtual=True):
             if not self.share_word_embeddings:
                 raise Exception('word_embeddings_weight() called for last '
                                 'stage, but share_word_embeddings is false')
...
...
megatron/model/transformer.py
View file @
41a64613
...
...
@@ -552,7 +552,27 @@ class ParallelTransformer(MegatronModule):
                 layer_number,
                 layer_type=layer_type,
                 self_attn_mask_type=self_attn_mask_type)
 
-        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+        if args.virtual_pipeline_model_parallel_size is not None:
+            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
+                'num_layers_per_stage must be divisible by ' \
+                'virtual_pipeline_model_parallel_size'
+            # Number of layers in each model chunk is the number of layers in the stage,
+            # divided by the number of model chunks in a stage.
+            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0]  [2]  [4]  [6]
+            # Stage 1: [1]  [3]  [5]  [7]
+            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0, 1]  [4, 5]
+            # Stage 1: [2, 3]  [6, 7]
+            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
+                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
+                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
+        else:
+            # Each stage gets a contiguous set of layers.
+            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
 
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
...
...
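The offset arithmetic in the hunk above is easy to check by hand. The following standalone sketch reproduces the 8-layer, 2-stage, 4-model-chunk case from the comment with plain Python and illustrative values (no mpu calls); layer indices are 0-based, matching the comment.

# Standalone check of the layer-to-chunk assignment described in the comment.
total_layers = 8
pipeline_size = 2      # pipeline stages
num_chunks = 4         # virtual_pipeline_model_parallel_size
layers_per_chunk = total_layers // (pipeline_size * num_chunks)   # 1

for pipeline_rank in range(pipeline_size):
    chunks = []
    for virtual_rank in range(num_chunks):
        offset = virtual_rank * (total_layers // num_chunks) + \
            pipeline_rank * layers_per_chunk
        chunks.append(list(range(offset, offset + layers_per_chunk)))
    print('Stage %d: %s' % (pipeline_rank, chunks))
# Stage 0: [[0], [2], [4], [6]]
# Stage 1: [[1], [3], [5], [7]]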
megatron/mpu/__init__.py
View file @
41a64613
...
...
@@ -38,6 +38,7 @@ from .initialize import get_pipeline_model_parallel_next_rank
 from .initialize import get_pipeline_model_parallel_prev_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
+from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
...
...
@@ -58,6 +59,8 @@ from .random import get_cuda_rng_tracker
 from .random import init_checkpointed_activations_memory_buffer
 from .random import model_parallel_cuda_manual_seed
 from .random import reset_checkpointed_activations_memory_buffer
+from .random import gather_split_1d_tensor
+from .random import split_tensor_into_1d_equal_chunks
 from .utils import divide
 from .utils import split_tensor_along_last_dim
megatron/mpu/initialize.py
View file @
41a64613
...
...
@@ -32,6 +32,9 @@ _EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
 
+_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+
 # These values enable us to change the mpu sizes on the fly.
 _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
...
...
@@ -48,7 +51,8 @@ def is_unitialized():
 def initialize_model_parallel(tensor_model_parallel_size_=1,
-                              pipeline_model_parallel_size_=1):
+                              pipeline_model_parallel_size_=1,
+                              virtual_pipeline_model_parallel_size_=None):
     """
     Initialize model data parallel groups.
...
...
@@ -91,6 +95,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
     num_data_parallel_groups = world_size // data_parallel_size
 
+    if virtual_pipeline_model_parallel_size_ is not None:
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
+
     rank = torch.distributed.get_rank()
 
     # Build the data-parallel groups.
...
...
@@ -258,17 +268,46 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
 
 
-def is_pipeline_first_stage():
+def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        if get_virtual_pipeline_model_parallel_world_size() is not None and \
+            get_virtual_pipeline_model_parallel_rank() != 0:
+            return False
     return get_pipeline_model_parallel_rank() == 0
 
 
-def is_pipeline_last_stage():
+def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        virtual_pipeline_model_parallel_world_size = \
+            get_virtual_pipeline_model_parallel_world_size()
+        if virtual_pipeline_model_parallel_world_size is not None and \
+            get_virtual_pipeline_model_parallel_rank() != (
+                virtual_pipeline_model_parallel_world_size - 1):
+            return False
     return get_pipeline_model_parallel_rank() == (
         get_pipeline_model_parallel_world_size() - 1)
 
 
+def get_virtual_pipeline_model_parallel_rank():
+    """Return the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+
+
+def set_virtual_pipeline_model_parallel_rank(rank):
+    """Set the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
+
+
+def get_virtual_pipeline_model_parallel_world_size():
+    """Return the virtual pipeline-parallel world size."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+
+
 def get_tensor_model_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
...
...
@@ -276,11 +315,13 @@ def get_tensor_model_parallel_src_rank():
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size
 
 
 def get_pipeline_model_parallel_first_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
     return _PIPELINE_GLOBAL_RANKS[0]
 
 
 def get_pipeline_model_parallel_last_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
...
...
@@ -294,6 +335,7 @@ def get_pipeline_model_parallel_next_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
 
 
 def get_pipeline_model_parallel_prev_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
...
...
@@ -301,6 +343,7 @@ def get_pipeline_model_parallel_prev_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
 
 
 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
     return torch.distributed.get_world_size(group=get_data_parallel_group())
...
...
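To make the new ignore_virtual flag concrete, here is a toy, self-contained rendering of the same check with hard-coded ranks (it does not touch the mpu globals): with 2 pipeline stages and 4 virtual chunks, physical rank 0 only counts as the first stage while it is running chunk 0, unless the virtual rank is explicitly ignored.

# Toy illustration with assumed values; the real checks read the mpu globals.
pipeline_rank = 0          # physically first pipeline stage
virtual_rank = 2           # currently running its third model chunk
virtual_world_size = 4

def is_pipeline_first_stage(ignore_virtual=False):
    if not ignore_virtual and virtual_world_size is not None \
            and virtual_rank != 0:
        return False
    return pipeline_rank == 0

print(is_pipeline_first_stage())                     # False: mid-model chunk
print(is_pipeline_first_stage(ignore_virtual=True))  # True: physical stage 0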
megatron/optimizer/__init__.py
View file @
41a64613
...
...
@@ -23,7 +23,7 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
 
 
-def _get_params_for_weight_decay_optimization(module):
+def _get_params_for_weight_decay_optimization(modules):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and baises will have no weight decay but the rest will.
     """
...
...
@@ -32,18 +32,19 @@ def _get_params_for_weight_decay_optimization(module):
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module_ in module.modules():
-        if isinstance(module_, LayerNorm):
-            no_weight_decay_params['params'].extend(
-                [p for p in list(module_._parameters.values())
-                 if p is not None])
-        else:
-            weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n != 'bias'])
-            no_weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n == 'bias'])
+    for module in modules:
+        for module_ in module.modules():
+            if isinstance(module_, LayerNorm):
+                no_weight_decay_params['params'].extend(
+                    [p for p in list(module_._parameters.values())
+                     if p is not None])
+            else:
+                weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n != 'bias'])
+                no_weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n == 'bias'])
 
     return weight_decay_params, no_weight_decay_params
...
...
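The change above only generalizes the helper from a single module to a list of modules; the grouping rule itself is unchanged. Below is a minimal, self-contained sketch of that rule using plain torch.nn modules; the toy model, learning rate, and the use of AdamW are assumptions for illustration only.

# Toy example of the two parameter groups: LayerNorm parameters and biases go
# into a zero-weight-decay group, everything else keeps the default decay.
import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))
decay, no_decay = [], []
for module_ in model.modules():
    for name, param in module_._parameters.items():
        if param is None:
            continue
        if isinstance(module_, torch.nn.LayerNorm) or name == 'bias':
            no_decay.append(param)
        else:
            decay.append(param)

optimizer = torch.optim.AdamW(
    [{'params': decay},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=1e-4, weight_decay=0.01)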
megatron/p2p_communication.py
0 → 100644
View file @
41a64613
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
import operator
import torch

from megatron import get_args
from megatron import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 use_ring_exchange=False):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.

    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            mpu.get_tensor_model_parallel_world_size()
    else:
        tensor_chunk_shape = tensor_shape
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if use_ring_exchange:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()

    return tensor_recv_prev, tensor_recv_next


def recv_forward(timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor


def recv_backward(timers=None):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
        _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('forward-send').stop()


def send_backward(input_tensor_grad, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send').start()
        _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, timers=None):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('forward-send-backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
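One detail of _communicate() worth spelling out is the scatter-gather optimization: when enabled, each rank exchanges a flat 1-D chunk that is 1/tensor-model-parallel-world-size of the full activation rather than the whole [seq_length, micro_batch_size, hidden_size] tensor, and gathers it back on the receive side. A standalone arithmetic sketch with assumed sizes:

# Assumed example sizes, not read from Megatron args.
from functools import reduce
import operator

seq_length, micro_batch_size, hidden_size = 512, 8, 1024
tensor_model_parallel_world_size = 4

tensor_shape = (seq_length, micro_batch_size, hidden_size)
full_numel = reduce(operator.mul, tensor_shape, 1)
chunk_numel = full_numel // tensor_model_parallel_world_size

print(full_numel)   # 4194304 elements in the full activation
print(chunk_numel)  # 1048576 elements actually sent/received per rank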
megatron/schedules.py
0 → 100644
View file @
41a64613
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication


def forward_step(forward_step_func, data_iterator, model, input_tensor,
                 losses_reduced):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor."""
    timers = get_timers()

    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

    return output_tensor


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""
    args = get_args()

    timers = get_timers()
    timers('backward-compute').start()

    # Retain the grad on the input_tensor.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad

    timers('backward-compute').stop()

    return input_tensor_grad


@contextmanager
def dummy_handler():
    try:
        yield
    finally:
        pass


def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    Returns dictionary with losses."""
    assert len(model) == 1
    model = model[0]

    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for i in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator, model,
                                         input_tensor, losses_reduced)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator, model,
                                 input_tensor, losses_reduced)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced


def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
                                                  model, optimizer, timers,
                                                  forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer,
                          input_tensor,
                          output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    timers=timers)
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev, timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor, input_tensor_grad,
                recv_prev=recv_prev, recv_next=recv_next,
                timers=timers)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next, timers))

    return losses_reduced


def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches,
                                  num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced
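The bookkeeping in the interleaved schedule is easier to follow with concrete numbers. The sketch below re-derives the microbatch-to-chunk mapping and the warmup count outside of Megatron, for an assumed configuration of 2 pipeline stages, 4 model chunks per stage, and 8 microbatches per chunk (32 scheduled microbatches in total).

# Standalone illustration with assumed values (not Megatron code).
pipeline_parallel_size = 2
num_model_chunks = 4
microbatches_per_chunk = 8                  # stands in for get_num_microbatches()
num_microbatches = microbatches_per_chunk * num_model_chunks   # 32

def get_model_chunk_id(microbatch_id, forward):
    group = pipeline_parallel_size * num_model_chunks
    chunk = (microbatch_id % group) // pipeline_parallel_size
    return chunk if forward else num_model_chunks - chunk - 1

print([get_model_chunk_id(m, True) for m in range(8)])   # [0, 0, 1, 1, 2, 2, 3, 3]
print([get_model_chunk_id(m, False) for m in range(8)])  # [3, 3, 2, 2, 1, 1, 0, 0]

# Warmup forward passes for pipeline rank 0, using the formula in the schedule:
pipeline_parallel_rank = 0
warmup = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
warmup += (num_model_chunks - 1) * pipeline_parallel_size
print(min(warmup, num_microbatches))                      # 8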