OpenDAS / Megatron-LM
Commit 1fa6990c, authored Jan 31, 2022 by Lawrence McAfee
Added a 'no-op' layer to replace the transformer layer when num_layers == 0.
parent 5bc9f889
Showing 2 changed files with 55 additions and 8 deletions (+55 / -8):

  megatron/model/transformer.py   +49 / -2
  megatron/schedules.py            +6 / -6
megatron/model/transformer.py
@@ -542,6 +542,25 @@ class ParallelTransformerLayer(MegatronModule):

         return output


+# >>>
+class NoopTransformerLayer(MegatronModule):
+    """A single 'no-op' transformer layer.
+
+    The sole purpose of this layer is for when args.standalone_embedding_stage
+    == True. ?????
+    """
+
+    def __init__(self, layer_number):
+        super().__init__()
+        self.layer_number = layer_number
+
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):
+        return hidden_states.clone()
+# <<<
+
+
 class ParallelTransformer(MegatronModule):
     """Transformer class."""
@@ -569,6 +588,14 @@ class ParallelTransformer(MegatronModule):
         # <<<
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
+        # >>>
+        # if not self.pre_process and self.num_layers == 0:
+        #     raise Exception(">>>> t %d, p %d, v %d. <<<<" % (
+        #         mpu.get_tensor_model_parallel_rank(),
+        #         mpu.get_pipeline_model_parallel_rank(),
+        #         mpu.get_virtual_pipeline_model_parallel_rank(),
+        #     ))
+        # <<<

         # Transformer layers.
         def build_layer(layer_number):
@@ -610,8 +637,28 @@ class ParallelTransformer(MegatronModule):
         else:
             offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

-        self.layers = torch.nn.ModuleList(
-            [build_layer(i + 1 + offset) for i in range(self.num_layers)])
+        # >>>
+        if self.num_layers == 0:
+            # when args.standalone_embed_stage == True, virtual pipeline ranks
+            # on pipeline rank 0 will have zero transformer layers assigned to
+            # them. This will cause a couple optimization techniques to fail:
+            #
+            # 1. distributed checkpointing (we
+            # 2. pipeline output tensor deallocation (would fail because the
+            #    output tensor is the same object as the input tensor, and
+            #    thus we also deallocate the input tensor, which causes
+            #    autograd.backward to fail)
+            #
+            # to remedy this, we assign a 'no-op' layer on these ranks, which
+            # will pass the data flow through the checkpoint function, and in
+            # turn also results in the schedule's input and output tensors
+            # being separate objects.
+            self.num_layers = 1
+            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
+        else:
+            self.layers = torch.nn.ModuleList(
+                [build_layer(i + 1 + offset) for i in range(self.num_layers)])
+        # <<<

         if self.post_process:
             # Final layer norm before output.
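To illustrate the aliasing problem described in the comment of the last hunk, here is a self-contained sketch (not Megatron code; the '.data = torch.empty(1)' line only imitates what a pipeline schedule might do when it frees an output tensor's storage after sending it downstream). When a stage with zero real layers simply returns its input, the stage's output is the very same tensor object as its input, so freeing one frees the other; the no-op layer's clone() keeps them separate:

import torch

def passthrough(h):
    return h             # zero layers: same object in, same object out

def noop_layer(h):
    return h.clone()     # no-op layer: new object, new storage, gradients still flow

h_in = torch.randn(4, 8, requires_grad=True)

out_alias = passthrough(h_in)
out_clone = noop_layer(h_in)
print(id(out_alias) == id(h_in))   # True:  freeing out_alias also frees h_in
print(id(out_clone) == id(h_in))   # False: out_clone can be freed independently

# Imitate the schedule freeing the output tensor's storage after it has been sent:
out_alias.data = torch.empty(1)
print(h_in.shape)                  # torch.Size([1]): the input was clobbered as well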
megatron/schedules.py
@@ -335,12 +335,12 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         output_tensors[model_chunk_id].append(output_tensor)
         # >>>
-        if id(input_tensor) == id(output_tensor):
-            raise Exception("tp %d, pp %d, vp %d." % (
-                mpu.get_tensor_model_parallel_rank(),
-                mpu.get_pipeline_model_parallel_rank(),
-                mpu.get_virtual_pipeline_model_parallel_rank(),
-            ))
+        # if id(input_tensor) == id(output_tensor):
+        #     raise Exception("tp %d, pp %d, vp %d." % (
+        #         mpu.get_tensor_model_parallel_rank(),
+        #         mpu.get_pipeline_model_parallel_rank(),
+        #         mpu.get_virtual_pipeline_model_parallel_rank(),
+        #         ))
         # <<<

         # if forward-only, no need to save tensors for a backward pass
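The check commented out here was a temporary assertion that the interleaved schedule never sees a stage whose output tensor is the same object as its input; with the no-op layer in place that condition should no longer occur. A small stand-in sketch of that invariant (the NoopTransformerLayer below is a simplified copy of the class added in transformer.py, built on nn.Module rather than MegatronModule):

import torch
import torch.nn as nn

class NoopTransformerLayer(nn.Module):
    """Simplified stand-in for the layer added in megatron/model/transformer.py."""

    def __init__(self, layer_number):
        super().__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask=None):
        return hidden_states.clone()

stage = nn.ModuleList([NoopTransformerLayer(1)])

input_tensor = torch.randn(16, 4, 32, requires_grad=True)
output_tensor = input_tensor
for layer in stage:
    output_tensor = layer(output_tensor, attention_mask=None)

# This is the condition the disabled 'raise Exception(...)' guarded against:
assert id(input_tensor) != id(output_tensor)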