wuxk1 / Megatron-LM
"examples/vscode:/vscode.git/clone" did not exist on "4c9959f6b95e20f0a4a0a45c21d168ee7b568dc9"
Commit 8fb2bc8c, authored Oct 28, 2020 by Deepak Narayanan
Parent: 1271fd73

    Clarifications in comments and minor refactoring to make main training loop more readable
Showing 6 changed files with 52 additions and 18 deletions:
    megatron/fp16/fp16.py            +0   -1
    megatron/model/bert_model.py    +12   -2
    megatron/model/gpt2_model.py    +12   -2
    megatron/model/realm_model.py    +1   -2
    megatron/training.py            +27  -10
    pretrain_ict.py                  +0   -1
megatron/fp16/fp16.py
@@ -577,7 +577,6 @@ class FP16_Optimizer(object):
         # a loss scale that works. After you find a loss scale that works, do a final dummy
         # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
         # discarding the iteration, but probably wouldn't improve overall efficiency.
-        # Convert output_tensor to float if it's the loss, otherwise stay in half precision.
         self.loss_scaler.backward(output_tensor, retain_graph=retain_graph,
                                   output_tensor_grad=output_tensor_grad)
         if update_master_grads:
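The backward call above has to cover two situations in a pipelined run: on the last stage, output_tensor is the fp16 loss and output_tensor_grad is None, so the loss is scaled before backpropagation; on earlier stages, output_tensor is an activation and output_tensor_grad is presumably the gradient received from the next stage. A minimal sketch of that dispatch, with hypothetical names (this is not the FP16_Optimizer or LossScaler internals; `scale` stands in for the current loss-scale value):

import torch

def scaler_backward_sketch(output_tensor, output_tensor_grad, scale,
                           retain_graph=False):
    if output_tensor_grad is None:
        # Last pipeline stage: output_tensor is the loss. Scale it so fp16
        # gradients stay in a representable range, then backpropagate.
        (output_tensor.float() * scale).backward(retain_graph=retain_graph)
    else:
        # Earlier stage: output_tensor is an activation; backpropagate the
        # gradient handed down from the next stage.
        torch.autograd.backward(output_tensor,
                                grad_tensors=output_tensor_grad,
                                retain_graph=retain_graph)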
megatron/model/bert_model.py
@@ -149,6 +149,17 @@ class BertModelBase(MegatronModule):
             init_method=init_method,
             scaled_init_method=scaled_init_method)

+        # Parameters are shared between the word embeddings layer, and the heads at
+        # the end of the model. In a pipelined setup with more than one stage, the
+        # initial embedding layer and the head are on different workers, so we do
+        # the following:
+        # 1. Create a second copy of word_embeddings on the last stage, with initial
+        #    parameters of 0.0.
+        # 2. Do an all-reduce between the first and last stage to ensure that the
+        #    two copies of word_embeddings start off with the same parameter values.
+        # 3. In the training loop, before an all-reduce between the grads of the two
+        #    word_embeddings layers to ensure that every applied weight update is the
+        #    same on both stages.
         if mpu.is_pipeline_last_stage():
             if not mpu.is_pipeline_first_stage():
                 self._word_embeddings_for_head_key = 'word_embeddings_for_head'
@@ -169,8 +180,7 @@ class BertModelBase(MegatronModule):
            self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                init_method)
            self._binary_head_key = 'binary_head'

-        # Ensure that first and last stages have the same initial embedding weights.
+        # Ensure that first and last stages have the same initial parameter values.
         if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
             torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                          group=mpu.get_embedding_group())
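The three numbered steps above describe how the input embedding and the output head keep a single logical word_embeddings matrix even though they live on different pipeline stages. A hedged sketch of steps 1 and 2, using stand-in names (embedding_group, is_first_stage, is_last_stage are assumptions, not the mpu API):

import torch
import torch.distributed as dist
import torch.nn as nn

def build_shared_word_embeddings(vocab_size, hidden_size,
                                 is_first_stage, is_last_stage,
                                 embedding_group):
    """Sketch: give the first and last pipeline stages identical copies of
    the word embedding weights before training starts."""
    if not (is_first_stage or is_last_stage):
        return None  # middle stages hold no copy of the embeddings
    word_embeddings = nn.Embedding(vocab_size, hidden_size)
    if is_last_stage and not is_first_stage:
        # Step 1: the head-side copy starts at 0.0 ...
        word_embeddings.weight.data.fill_(0)
    # Step 2: ... so a sum all-reduce over the embedding group leaves both
    # stages with the first stage's randomly initialized values.
    dist.all_reduce(word_embeddings.weight.data, group=embedding_group)
    return word_embeddings

Step 3, the per-iteration gradient all-reduce, appears in the megatron/training.py hunks further down.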
megatron/model/gpt2_model.py
@@ -79,6 +79,17 @@ class GPT2ModelBase(MegatronModule):
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))

+        # Parameters are shared between the word embeddings layer, and the heads at
+        # the end of the model. In a pipelined setup with more than one stage, the
+        # initial embedding layer and the head are on different workers, so we do
+        # the following:
+        # 1. Create a second copy of word_embeddings on the last stage, with initial
+        #    parameters of 0.0.
+        # 2. Do an all-reduce between the first and last stage to ensure that the
+        #    two copies of word_embeddings start off with the same parameter values.
+        # 3. In the training loop, before an all-reduce between the grads of the two
+        #    word_embeddings layers to ensure that every applied weight update is the
+        #    same on both stages.
         if mpu.is_pipeline_last_stage():
             if not mpu.is_pipeline_first_stage():
                 self._word_embeddings_for_head_key = 'word_embeddings_for_head'
@@ -89,8 +100,7 @@ class GPT2ModelBase(MegatronModule):
                args.padded_vocab_size, args.hidden_size,
                init_method=init_method_normal(args.init_method_std))
            self.word_embeddings.weight.data.fill_(0)

-        # Ensure that first and last stages have the same initial embedding weights.
+        # Ensure that first and last stages have the same initial parameter values.
         if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
             torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                          group=mpu.get_embedding_group())
megatron/model/realm_model.py
@@ -18,8 +18,7 @@ def general_ict_model_provider(only_query_model=False, only_block_model=False):
     args = get_args()
     assert args.ict_head_size is not None, \
         "Need to specify --ict-head-size to provide an ICTBertModel"
-    assert args.tensor_model_parallel_size == 1, \
+    assert args.tensor_model_parallel_size == 1 and args.pipeline_model_parallel_size == 1, \
         "Model parallel size > 1 not supported for ICT"

     print_rank_0('building ICTBertModel...')
megatron/training.py
@@ -361,6 +361,9 @@ def train_step(forward_step_func, data_iterator,
     # Compute number of microbatches in a minibatch.
     num_microbatches_in_minibatch = args.num_microbatches_in_minibatch

+    # For now, perform training without warmup. Perform forward
+    # passes for all microbatches, then backward passes for all
+    # microbatches.
     # TODO: Switch to the following schedule to facilitate more
     # memory-efficient training.
     # num_warmup_microbatches = \
@@ -369,9 +372,6 @@ def train_step(forward_step_func, data_iterator,
     # num_warmup_microbatches = min(
     #     num_warmup_microbatches,
     #     num_microbatches_in_minibatch)
-    # For now, perform training without warmup. Perform forward
-    # passes for all microbatches, then backward passes for all
-    # microbatches.
     num_warmup_microbatches = num_microbatches_in_minibatch

     input_tensors = []
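The commented-out schedule referenced in the TODO caps how many forward passes run before backward passes begin. The exact expression is elided in this diff, so the sketch below is an assumption based on the usual 1F1B-style pipeline schedule rather than the module's code:

def num_warmup_microbatches_sketch(pipeline_world_size, pipeline_rank,
                                   num_microbatches_in_minibatch):
    # Each stage only needs enough forward passes "in flight" to fill the
    # pipeline ahead of it; after that it can alternate forward and backward,
    # which bounds the number of activations it has to keep around.
    num_warmup = pipeline_world_size - pipeline_rank - 1
    return min(num_warmup, num_microbatches_in_minibatch)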
@@ -381,17 +381,31 @@ def train_step(forward_step_func, data_iterator,
     # Run warmup forward passes.
     timers('forward').start()
     for i in range(num_warmup_microbatches):
-        forward_step_with_communication(
-            forward_step_func, data_iterator, model,
-            input_tensors, output_tensors, losses_reduced, timers)
+        if args.pipeline_model_parallel_size > 1:
+            forward_step_with_communication(
+                forward_step_func, data_iterator, model,
+                input_tensors, output_tensors, losses_reduced, timers)
+        else:
+            input_tensor = None
+            loss, loss_reduced = forward_step_func(data_iterator, model,
+                                                   input_tensor)
+            output_tensor = loss
+            losses_reduced.append(loss_reduced)
+            input_tensors.append(input_tensor)
+            output_tensors.append(output_tensor)
     timers('forward').stop()

     # Run cooldown backward passes.
     timers('backward').start()
     for i in range(num_warmup_microbatches):
-        backward_step_with_communication(
-            optimizer, model, input_tensors, output_tensors, timers)
+        if args.pipeline_model_parallel_size > 1:
+            backward_step_with_communication(
+                optimizer, model, input_tensors, output_tensors, timers)
+        else:
+            input_tensor = input_tensors.pop(0)
+            output_tensor = output_tensors.pop(0)
+            output_tensor_grad = None
+            backward_step(optimizer, model, input_tensor, output_tensor,
+                          output_tensor_grad)

     # All-reduce if needed.
     if args.DDP_impl == 'local':
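Note how input_tensors and output_tensors act as FIFO queues: every warmup forward pass appends one entry, every cooldown backward pass pops from the front, so microbatches are backpropagated in the order they were forwarded, and all of their activations stay alive until the cooldown loop reaches them. A toy, self-contained illustration of that bookkeeping (dummy tensors, not Megatron code):

import torch

input_tensors, output_tensors = [], []

# "Warmup" loop: forward passes only, one per microbatch.
for _ in range(4):
    x = torch.randn(2, 8, requires_grad=True)
    y = (x * 2.0).sum()
    input_tensors.append(x)
    output_tensors.append(y)

# "Cooldown" loop: backward passes, oldest microbatch first.
while output_tensors:
    x = input_tensors.pop(0)
    y = output_tensors.pop(0)
    y.backward()  # gradients for this microbatch accumulate in x.grad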
@@ -400,7 +414,10 @@ def train_step(forward_step_func, data_iterator,
                                fp32_allreduce=args.fp32_allreduce)
         timers('allreduce').stop()

-    # All-reduce across first and last stages.
+    # All-reduce word_embeddings' grad across first and last stages to ensure
+    # that word_embeddings parameters stay in sync.
+    # This should only run for models that support pipelined model parallelism
+    # (BERT and GPT-2).
     timers('backward-embedding-all-reduce').start()
     if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
             args.pipeline_model_parallel_size > 1:
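This hunk is step 3 from the bert_model.py / gpt2_model.py comment: after the backward passes, the first and last stages sum the gradients of their word_embeddings copies so that both apply the same weight update. A minimal sketch with stand-in arguments (the real code goes through the mpu helpers and the model's word_embeddings_weight() accessor):

import torch.distributed as dist

def allreduce_word_embedding_grads(word_embeddings_weight, is_first_stage,
                                   is_last_stage, pipeline_world_size,
                                   embedding_group):
    # Only the first and last pipeline stages hold a copy of the embeddings,
    # and the sync is only needed when the model is actually pipelined.
    if (is_first_stage or is_last_stage) and pipeline_world_size > 1:
        if word_embeddings_weight.grad is not None:
            dist.all_reduce(word_embeddings_weight.grad,
                            group=embedding_group)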
pretrain_ict.py
@@ -32,7 +32,6 @@ from megatron.data.realm_dataset_utils import get_ict_batch

 def pretrain_ict_model_provider():
     args = get_args()
-    assert args.pipeline_model_parallel_size == 1, 'pipeline_model_parallel_size must be 1!'
     return general_ict_model_provider(False, False)