OpenDAS / Megatron-LM · Commits

Commit b93bef00, authored Feb 01, 2022 by Lawrence McAfee
Parent: bea16fa3

    comments, cleanup.

Showing 6 changed files with 39 additions and 239 deletions
megatron/arguments.py           +2   -16
megatron/model/transformer.py   +18  -62
megatron/mpu/initialize.py      +11  -59
megatron/p2p_communication.py   +6   -19
megatron/schedules.py           +0   -36
megatron/training.py            +2   -47
megatron/arguments.py (+2, -16)

@@ -141,24 +141,9 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
             'number of layers is not divisible by number of layers per virtual ' \
             'pipeline stage'
-        # >>>
-        # args.virtual_pipeline_model_parallel_size = \
-        #     (args.num_layers // args.pipeline_model_parallel_size) // \
-        #     args.num_layers_per_virtual_pipeline_stage
-        # <<<
         args.virtual_pipeline_model_parallel_size = \
             (args.num_layers // args.transformer_pipeline_model_parallel_size) // \
             args.num_layers_per_virtual_pipeline_stage
-        # >>>
-        # from lutil import pax
-        # pax({
-        #     "num_layers" : args.num_layers,
-        #     "pipeline size" : args.pipeline_model_parallel_size,
-        #     "transformer size" : transformer_pipeline_size,
-        #     "num virt layers" : args.num_layers_per_virtual_pipeline_stage,
-        #     "virtual size" : args.virtual_pipeline_model_parallel_size,
-        # })
-        # <<<
     else:
         args.virtual_pipeline_model_parallel_size = None

@@ -707,7 +692,8 @@ def _add_distributed_args(parser):
     group.add_argument('--standalone-embed-stage', action='store_true',
                        default=False, help='If set, *input* embedding layer '
                        'is placed on its own pipeline stage, without any '
-                       'transformer layers.')
+                       'transformer layers. (For T5, this flag currently only '
+                       'affects the encoder embedding.)')
     return parser
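The surviving assignment above derives the virtual (interleaved) pipeline size from the number of layers each transformer rank owns. A minimal arithmetic sketch, using hypothetical standalone values in place of the parsed args object, and assuming transformer_pipeline_model_parallel_size is the pipeline size minus one when a standalone embedding stage is used (as the rest of this commit suggests):

# Hypothetical values standing in for args.* fields (not Megatron-LM's parser).
num_layers = 24
pipeline_model_parallel_size = 4
standalone_embed_stage = True
num_layers_per_virtual_pipeline_stage = 2

# Assumption: with a standalone embedding stage, one pipeline rank holds only
# the input embedding, so transformer layers span one fewer rank.
transformer_pipeline_model_parallel_size = (
    pipeline_model_parallel_size - 1 if standalone_embed_stage
    else pipeline_model_parallel_size)

assert num_layers % num_layers_per_virtual_pipeline_stage == 0, \
    'number of layers is not divisible by number of layers per virtual pipeline stage'

virtual_pipeline_model_parallel_size = (
    (num_layers // transformer_pipeline_model_parallel_size)
    // num_layers_per_virtual_pipeline_stage)

# 24 layers over 3 transformer ranks -> 8 layers per rank -> 4 chunks of 2 layers.
print(virtual_pipeline_model_parallel_size)  # 4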
megatron/model/transformer.py (+18, -62)

@@ -542,12 +542,20 @@ class ParallelTransformerLayer(MegatronModule):
         return output

-# >>>
 class NoopTransformerLayer(MegatronModule):
     """A single 'no-op' transformer layer.

-    The sole purpose of this layer is for when args.standalone_embed_stage
-    == True. ?????
+    The sole purpose of this layer is for when a standalone embedding layer
+    is used (i.e., args.standalone_embed_stage == True). In this case,
+    zero transformer layers are assigned when pipeline rank == 0. Additionally,
+    when virtual pipeline rank >= 1, zero total model parameters are created
+    (virtual rank 0 contains the input embedding). This results in the model's
+    input and output tensors being the same, which causes an error when
+    performing certain memory optimizations on the output tensor (e.g.,
+    deallocating it). Thus, this layer disconnects the input from the output
+    via a clone. Since ranks containing a no-op layer are generally under-
+    utilized (both compute and memory), there's no worry of any performance
+    degradation.
     """

     def __init__(self, layer_number):

@@ -558,7 +566,6 @@ class NoopTransformerLayer(MegatronModule):
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
         return hidden_states.clone()
-# <<<

 class ParallelTransformer(MegatronModule):

@@ -583,19 +590,8 @@ class ParallelTransformer(MegatronModule):
         self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

         # Number of layers.
-        # >>>
-        # raise Exception("rank %d." % torch.distributed.get_rank())
-        # <<<
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
-        # >>>
-        # if not self.pre_process and self.num_layers == 0:
-        #     raise Exception(">>>> t %d, p %d, v %d. <<<<" % (
-        #         mpu.get_tensor_model_parallel_rank(),
-        #         mpu.get_pipeline_model_parallel_rank(),
-        #         mpu.get_virtual_pipeline_model_parallel_rank(),
-        #     ))
-        # <<<

         # Transformer layers.
         def build_layer(layer_number):

@@ -637,28 +633,20 @@ class ParallelTransformer(MegatronModule):
             else:
                 offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

-        # >>>
         if self.num_layers == 0:
-            # when args.standalone_embed_stage == True, virtual pipeline ranks
+            # When a standalone embedding stage is used (e.g.,
+            # args.standalone_embed_stage == True), virtual pipeline ranks
             # on pipeline rank 0 will have zero transformer layers assigned to
-            # them. This will cause a couple optimization techniques to fail:
-            #
-            #   1. distributed checkpointing (we
-            #   2. pipeline output tensor deallocation (would fail because the
-            #      output tensor is the same object as the input tensor, and
-            #      thus we also deallocate the input tensor, which causes
-            #      autograd.backward to fail)
-            #
-            # to remedy this, we assign a 'no-op' layer on these ranks, which
-            # will pass the data flow through the checkpoint function, and in
-            # turn also results in the schedule's input and output tensors
-            # being separate objects.
+            # them. This results in the model's input and output tensors being
+            # the same, which will cause failure for certain output tensor
+            # optimizations (e.g., pipeline output deallocation). To remedy
+            # this, we assign a 'no-op' layer on these ranks, which will
+            # disconnect the input tensor from the output tensor.
             self.num_layers = 1
             self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
         else:
             self.layers = torch.nn.ModuleList(
                 [build_layer(i + 1 + offset) for i in range(self.num_layers)])
-        # <<<

         if self.post_process:
             # Final layer norm before output.

@@ -745,18 +733,6 @@ class ParallelTransformer(MegatronModule):
             # See set_input_tensor()
             hidden_states = self.input_tensor

-        # >>>
-        # if not self.pre_process and self.num_layers == 0:
-        #     # raise Exception("tp %d, pp %d, vp %d ... hidden states %s, input tensor %s." % (
-        #     #     mpu.get_tensor_model_parallel_rank(),
-        #     #     mpu.get_pipeline_model_parallel_rank(),
-        #     #     mpu.get_virtual_pipeline_model_parallel_rank(),
-        #     #     "--" if hidden_states is None else str(hidden_states.shape),
-        #     #     "--" if self.input_tensor is None else str(self.input_tensor.shape),
-        #     # ))
-        #     hidden_states = hidden_states.clone()
-        # <<<
         # Viewless tensor.
         # - We only need to create a viewless tensor in the case of micro batch
         #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'

@@ -804,26 +780,6 @@ class ParallelTransformer(MegatronModule):
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
-            # >>>
-            # if True or output._base is not None:
-            #     # from lutil import pax, tp
-            #     # pax({
-            #     #     "hidden_states" : tp(hidden_states),
-            #     #     "output" : tp(output),
-            #     # })
-            #     # raise Exception(">>> rank %d, view %d, hid '%s', out '%s'. <<<" %(
-            #     #     torch.distributed.get_rank(),
-            #     #     output._base is not None,
-            #     #     str(hidden_states.shape),
-            #     #     str(output.shape),
-            #     # ))
-            #     args = get_args()
-            #     raise Exception(">>> rank %d, hid %d, view %d. <<<" %(
-            #         torch.distributed.get_rank(),
-            #         args.hidden_size,
-            #         output._base is not None,
-            #     ))
-            # <<<
         else:
             output = hidden_states
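The rewritten docstring explains the point of the no-op layer: with zero transformer layers on a rank, the stage's output would be the very same tensor object as its input, so deallocating the output would also free the input. A self-contained sketch of the idea follows; it is a plain torch.nn.Module with a simplified signature, not the MegatronModule subclass from the diff.

import torch

class NoopLayerSketch(torch.nn.Module):
    """Sketch of a 'no-op' layer: holds no parameters, but clones its input
    so the stage's output is a distinct tensor object from its input."""

    def __init__(self, layer_number):
        super().__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask=None):
        # clone() allocates fresh storage while staying attached to the
        # autograd graph, so the schedule can free the output without
        # touching the input tensor an upstream stage still needs.
        return hidden_states.clone()

x = torch.randn(8, 2, 16, requires_grad=True)   # [seq, batch, hidden]
y = NoopLayerSketch(1)(x)
assert y is not x and torch.equal(y, x)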
megatron/mpu/initialize.py (+11, -59)

@@ -269,9 +269,6 @@ def set_tensor_model_parallel_world_size(world_size):

 def set_pipeline_model_parallel_world_size(world_size):
-    # >>>
-    raise Exception("hi.")
-    # <<<
     """Set the pipeline model parallel size"""
     global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
     _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size

@@ -290,9 +287,6 @@ def get_pipeline_model_parallel_world_size():
     global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
     if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
         return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-    # >>>
-    # raise Exception("hi.")
-    # <<<
     return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())

@@ -328,49 +322,34 @@ def get_num_layers(args, is_encoder_and_decoder_model):
     """Compute the number of transformer layers resident on the current rank."""
     if get_pipeline_model_parallel_world_size() > 1:
         if is_encoder_and_decoder_model:
-            # >>>
-            # raise Exception("fix for t5.")
-            # <<<
             assert args.pipeline_model_parallel_split_rank is not None
-            # >>>
-            # num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
-            # +++
+            # When a standalone embedding stage is used, a rank is taken from
+            # the encoder's ranks, to be used for the encoder's embedding
+            # layer. This way, the rank referenced by the 'split rank' remains
+            # the same whether or not a standalone embedding stage is used.
             num_ranks_in_encoder = (
                 args.pipeline_model_parallel_split_rank - 1
                 if args.standalone_embed_stage else
                 args.pipeline_model_parallel_split_rank
             )
-            # <<<
-            # >>>
-            # num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
-            # +++
             num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
-            # <<<
-            # >>>
-            # raise Exception(">>>> standalone %d, encoder %d, decoder %d. <<<<" % (
-            #     args.standalone_embed_stage,
-            #     num_ranks_in_encoder,
-            #     num_ranks_in_decoder,
-            # ))
-            # <<<
             assert args.num_layers % num_ranks_in_encoder == 0, \
                 'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder)
             assert args.num_layers % num_ranks_in_decoder == 0, \
                 'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
-            if is_pipeline_stage_before_split(): # args):
+            if is_pipeline_stage_before_split():
                 num_layers = args.num_layers // num_ranks_in_encoder
             else:
                 num_layers = args.num_layers // num_ranks_in_decoder
         else:
-            # >>>
-            # transformer_pipeline_size = (
-            #     get_pipeline_model_parallel_world_size() - 1
-            #     if args.standalone_embed_stage else
-            #     get_pipeline_model_parallel_world_size()
-            # )
-            # <<<
             assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
                 'num_layers must be divisible by transformer_pipeline_model_parallel_size'
+            # When a standalone embedding stage is used, all transformer layers
+            # are divided among pipeline rank >= 1, while on pipeline rank 0,
+            # ranks either contain the input embedding layer (virtual pp rank 0),
+            # or no layers at all (virtual pp rank >= 1).
             num_layers = (
                 0
                 if args.standalone_embed_stage

@@ -379,17 +358,6 @@ def get_num_layers(args, is_encoder_and_decoder_model):
             )
     else:
         num_layers = args.num_layers
-    # >>>
-    # from lutil import pax
-    # pax(7, {
-    #     "rank" : torch.distributed.get_rank(),
-    #     "pipeline rank" : "%d / %d" % (
-    #         get_pipeline_model_parallel_rank(),
-    #         get_pipeline_model_parallel_world_size(),
-    #     ),
-    #     "num_layers" : num_layers,
-    # })
-    # <<<
     return num_layers

@@ -438,9 +406,6 @@ def is_rank_in_position_embedding_group():
     return rank in _POSITION_EMBEDDING_GLOBAL_RANKS

-# >>>
-# def is_pipeline_stage_before_split(args, rank=None):
-# <<<
 def is_pipeline_stage_before_split(rank=None):
     """Return True if pipeline stage executes encoder block for a model
     with both encoder and decoder."""

@@ -448,11 +413,6 @@ def is_pipeline_stage_before_split(rank=None):
         return True
     if rank is None:
         rank = get_pipeline_model_parallel_rank()
-    # >>>
-    # if args.standalone_embed_stage:
-    #     rank += 1
-    assert isinstance(rank, (type(None), int)), "rank == <%s>." % type(rank).__name__
-    # <<<
     global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
     if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
         return True

@@ -461,9 +421,6 @@ def is_pipeline_stage_before_split(rank=None):
     return False

-# >>>
-# def is_pipeline_stage_after_split(args, rank=None):
-# <<<
 def is_pipeline_stage_after_split(rank=None):
     """Return True if pipeline stage executes decoder block for a model
     with both encoder and decoder."""

@@ -471,11 +428,6 @@ def is_pipeline_stage_after_split(rank=None):
         return True
     if rank is None:
         rank = get_pipeline_model_parallel_rank()
-    # >>>
-    # if args.standalone_embed_stage:
-    #     rank += 1
-    assert isinstance(rank, (type(None), int)), "rank == <%s>." % type(rank).__name__
-    # <<<
     global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
     if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
         return True
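The cleaned-up comments in get_num_layers describe how the encoder gives up one rank to the standalone embedding stage while the 'split rank' keeps its meaning. A rough standalone sketch of that partitioning arithmetic, with hypothetical integer arguments in place of the args object and the mpu rank queries:

def num_layers_on_rank(num_layers, pipeline_split_rank,
                       transformer_pipeline_size, standalone_embed_stage,
                       stage_is_before_split):
    """Sketch: transformer layers resident on one rank of an
    encoder-and-decoder model (hypothetical helper, not Megatron-LM's API)."""
    # The standalone embedding stage takes one rank away from the encoder,
    # so the 'split rank' means the same thing either way.
    num_ranks_in_encoder = (pipeline_split_rank - 1 if standalone_embed_stage
                            else pipeline_split_rank)
    num_ranks_in_decoder = transformer_pipeline_size - num_ranks_in_encoder
    assert num_layers % num_ranks_in_encoder == 0
    assert num_layers % num_ranks_in_decoder == 0
    return (num_layers // num_ranks_in_encoder if stage_is_before_split
            else num_layers // num_ranks_in_decoder)

# e.g. 12 layers, split rank 3, 6 transformer ranks, standalone embedding stage:
# 2 encoder ranks and 4 decoder ranks.
print(num_layers_on_rank(12, 3, 6, True, True))    # 6 layers per encoder rank
print(num_layers_on_rank(12, 3, 6, True, False))   # 3 layers per decoder rank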
megatron/p2p_communication.py (+6, -19)

@@ -136,35 +136,22 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         # To protect against race condition when using batch_isend_irecv().
         torch.cuda.synchronize()

-    # >>>
-    def make_viewless_tensor(t):
-        return mpu.make_viewless_tensor(t, requires_grad = True,
-                                        keep_graph = False)
-    # <<<
     # If using scatter-gather optimization, gather smaller chunks.
     if not override_scatter_gather_tensors_in_pipeline and \
             args.scatter_gather_tensors_in_pipeline:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            # >>>
-            # tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
-            #                                             requires_grad = True,
-            #                                             keep_graph = False)
-            # +++
-            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)
-            # <<<
+            tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
+                                                        requires_grad = True,
+                                                        keep_graph = False)
         if recv_next:
             tensor_recv_next = mpu.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            # >>>
-            # tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
-            #                                             requires_grad = True,
-            #                                             keep_graph = False)
-            # +++
-            tensor_recv_next = make_viewless_tensor(tensor_recv_next)
-            # <<<
+            tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
+                                                        requires_grad = True,
+                                                        keep_graph = False)

     return tensor_recv_prev, tensor_recv_next
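This hunk replaces the temporary local make_viewless_tensor helper with a direct call to mpu.make_viewless_tensor. The motivation implied by the surrounding code is that gather_split_1d_tensor(...).view(...) returns a view whose ._base keeps the full gathered buffer alive. A rough illustration of that general idea in plain PyTorch; mpu.make_viewless_tensor's actual implementation and its requires_grad/keep_graph semantics are not shown in this diff, so the helper below is only an assumption-labeled sketch:

import torch

def make_viewless_sketch(t):
    # Sketch only: produce a tensor with the same values but no ._base,
    # so freeing it later does not pin the original gathered buffer.
    out = torch.empty_like(t, requires_grad=t.requires_grad)
    with torch.no_grad():
        out.copy_(t)
    return out

flat = torch.arange(24.0)                    # stand-in for a gathered 1-D buffer
recv = flat.view(2, 3, 4).requires_grad_()   # mirrors .view(tensor_shape).requires_grad_()
print(recv._base is flat)                    # True: 'recv' is a view of 'flat'
viewless = make_viewless_sketch(recv)
print(viewless._base is None)                # True: no base tensor is referenced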
megatron/schedules.py (+0, -36)

@@ -34,25 +34,6 @@ def get_forward_backward_func():
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
-            # >>>
-            # from lutil import pax
-            # pax({
-            #     "num microbatches" : get_num_microbatches(),
-            #     "pipeline size" : args.pipeline_model_parallel_size,
-            # })
-            # <<<
-            # >>>
-            # assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
-            #     'number of microbatches is not divisible by pipeline-parallel ' \
-            #     'size when using interleaved schedule'
-            # assert get_num_microbatches() % \
-            #     args.transformer_pipeline_model_parallel_size == 0, \
-            #     'number of microbatches (%d) is not divisible by transformer-' \
-            #     'pipeline-model-parallel-size (%d) when using interleaved ' \
-            #     'schedule' % (
-            #         get_num_microbatches(),
-            #         args.transformer_pipeline_model_parallel_size,
-            #     )
             assert get_num_microbatches() % \
                 args.pipeline_model_parallel_size == 0, \
                 'number of microbatches (%d) is not divisible by pipeline-' \

@@ -60,7 +41,6 @@ def get_forward_backward_func():
                     get_num_microbatches(),
                     args.pipeline_model_parallel_size,
                 )
-            # <<<
         else:
             forward_backward_func = forward_backward_pipelining_without_interleaving
     else:

@@ -143,9 +123,6 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     unwrapped_model.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
-    # >>>
-    mpu.assert_viewless_tensor(output_tensor)
-    # <<<
     if mpu.is_pipeline_last_stage():
         output_tensor = loss_func(output_tensor)
         loss, loss_reduced = output_tensor

@@ -153,10 +130,6 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
         losses_reduced.append(loss_reduced)
     timers('forward-compute').stop()

-    # >>>
-    mpu.assert_viewless_tensor(output_tensor)
-    # <<<
     # If T5 model (or other model with encoder and decoder)
     # and in decoder stack, then send encoder_hidden_state
     # downstream as well.

@@ -341,15 +314,6 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                                      input_tensor, losses_reduced)
         output_tensors[model_chunk_id].append(output_tensor)

-        # >>>
-        # if id(input_tensor) == id(output_tensor):
-        #     raise Exception("tp %d, pp %d, vp %d." % (
-        #         mpu.get_tensor_model_parallel_rank(),
-        #         mpu.get_pipeline_model_parallel_rank(),
-        #         mpu.get_virtual_pipeline_model_parallel_rank(),
-        #     ))
-        # <<<
         # if forward-only, no need to save tensors for a backward pass
         if forward_only:
             input_tensors[model_chunk_id].pop()
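The assertion kept in get_forward_backward_func requires the number of microbatches to divide evenly by the pipeline-parallel size when the interleaved schedule is selected. A toy re-statement of that check, with plain integers standing in for get_num_microbatches() and args.pipeline_model_parallel_size:

def check_interleaved_schedule(num_microbatches, pipeline_model_parallel_size):
    """Sketch of the divisibility requirement for the interleaved schedule."""
    assert num_microbatches % pipeline_model_parallel_size == 0, \
        'number of microbatches (%d) is not divisible by pipeline-' \
        'model-parallel-size (%d) when using interleaved schedule' % (
            num_microbatches, pipeline_model_parallel_size)

check_interleaved_schedule(num_microbatches=8, pipeline_model_parallel_size=4)    # passes
# check_interleaved_schedule(num_microbatches=6, pipeline_model_parallel_size=4)  # would raise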
megatron/training.py (+2, -47)

@@ -136,14 +136,6 @@ def pretrain(train_valid_test_dataset_provider,
     timers('train/valid/test-data-iterators-setup').stop()
     print_datetime('after dataloaders are built')

-    # >>>
-    # from lutil import pax
-    # pax({
-    #     "model / len" : len(model),
-    #     # "do_train": args.do_train,
-    # })
-    # <<<
     # Print setup timing.
     print_rank_0('done with setup ...')
     timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])

@@ -207,14 +199,6 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     args = get_args()
     args.model_type = model_type

-    # >>>
-    # from lutil import pax
-    # pax({
-    #     "pipeline world size" : mpu.get_pipeline_model_parallel_world_size(),
-    #     "virtual size" : args.virtual_pipeline_model_parallel_size,
-    # })
-    # <<<
     # Build model.
     if mpu.get_pipeline_model_parallel_world_size() > 1 and \
        args.virtual_pipeline_model_parallel_size is not None:

@@ -232,13 +216,6 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
             )
             this_model.model_type = model_type
             model.append(this_model)
-        # >>>
-        # from lutil import pax
-        # pax({
-        #     "virtual size" : args.virtual_pipeline_model_parallel_size,
-        #     "model" : model,
-        # })
-        # <<<
     else:
         pre_process = mpu.is_pipeline_first_stage()
         post_process = mpu.is_pipeline_last_stage()

@@ -254,10 +231,8 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
             pre_process = rank == 0 or rank == split_rank
             post_process = (rank == (split_rank - 1)) or (
                 rank == (world_size - 1))
-            # >>>
-            add_encoder = mpu.is_pipeline_stage_before_split()  # args)
-            add_decoder = mpu.is_pipeline_stage_after_split()  # args)
-            # <<<
+            add_encoder = mpu.is_pipeline_stage_before_split()
+            add_decoder = mpu.is_pipeline_stage_after_split()
         model = model_provider_func(
             pre_process=pre_process,
             post_process=post_process,

@@ -370,21 +345,6 @@ def setup_model_and_optimizer(model_provider_func, model_type):
     model = get_model(model_provider_func, model_type)
-    # >>>
-    # if mpu.get_tensor_model_parallel_rank() == 0:
-    #     from lutil import pax
-    #     pax({
-    #         # "model" : model,
-    #         "model" : [
-    #             sum(t.nelement() for t in m.parameters())
-    #             for m in model
-    #         ],
-    #     })
-    # else:
-    #     torch.distributed.barrier()
-    #     exit(0)
-    # <<<
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
     optimizer = get_megatron_optimizer(unwrapped_model)

@@ -952,11 +912,6 @@ def build_train_valid_test_data_iterators(
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()

-    # >>>
-    # from lutil import pax
-    # pax({"hi": "there"})
-    # <<<
     # Build iterators.
     dl_type = args.dataloader_type
     assert dl_type in ['single', 'cyclic']
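The hunk at old line 254 tidies how each pipeline rank of an encoder-and-decoder model decides its role. A small sketch of that per-rank assignment, using plain integers instead of the mpu queries; the rank comparisons below stand in for is_pipeline_stage_before_split() / is_pipeline_stage_after_split() and are an assumption about their behavior, not the library calls themselves:

def stage_roles(rank, world_size, split_rank):
    """Sketch: role flags for one pipeline rank of an encoder-decoder model.
    'split_rank' is the first rank of the decoder."""
    pre_process = rank == 0 or rank == split_rank                        # owns an input embedding
    post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))  # owns a final norm/head
    add_encoder = rank < split_rank    # stand-in for is_pipeline_stage_before_split()
    add_decoder = rank >= split_rank   # stand-in for is_pipeline_stage_after_split()
    return pre_process, post_process, add_encoder, add_decoder

# 4 pipeline ranks, decoder starting at rank 2:
for r in range(4):
    print(r, stage_roles(r, world_size=4, split_rank=2))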