Commit fe2d623e in OpenDAS/Megatron-LM
Authored Mar 29, 2022 by Lawrence McAfee
Parent: 7ed649ed

    replace triple single w/ triple double quote.
Showing 2 changed files with 34 additions and 21 deletions:

    megatron/optimizer/distrib_optimizer.py   +23  -12
    megatron/optimizer/optimizer.py           +11   -9
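The change is mechanical: docstrings delimited with triple single quotes are rewritten to use the triple double quotes that PEP 257 recommends, and multi-line docstrings are reflowed so the opening quotes sit on their own line. Schematically (the function below is a stand-in, adapted from the gather_model_params hunk in optimizer.py):

    # Before
    def do_nothing():
        '''For the non-distributed optimizer, there is nothing to
        do here.'''

    # After
    def do_nothing():
        """
        For the non-distributed optimizer, there is nothing to
        do here.
        """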
megatron/optimizer/distrib_optimizer.py
@@ -29,9 +29,10 @@ from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper

 class Range:
-    '''A range represents a start and end points for indexing a shard
+    """
+    A range represents a start and end points for indexing a shard
     from a full tensor.
-    '''
+    """
     def __init__(self, start, end):
         self.start = start
         self.end = end
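For orientation, a minimal usage sketch of the Range class above: the (start, end) pair marks the slice of a flattened buffer that one data-parallel rank owns. The buffer below is a stand-in tensor, not the optimizer's real grad buffer.

    # Minimal sketch: a (start, end) Range used to slice one rank's shard
    # out of a flattened buffer; the buffer is a stand-in for illustration.
    import torch

    class Range:
        """A start/end pair for indexing a shard from a full tensor."""
        def __init__(self, start, end):
            self.start = start
            self.end = end

    full_buffer = torch.arange(16, dtype=torch.float32)  # stand-in grad buffer
    r = Range(start=4, end=8)                            # this rank owns [4, 8)
    shard = full_buffer[r.start:r.end]                   # tensor([4., 5., 6., 7.])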
@@ -43,7 +44,7 @@ class Range:

 class DistributedOptimizer(MixedPrecisionOptimizer):
-    '''
+    """
     Distributed optimizer, for all data types (fp16, bf16, and fp32).

     Arguments:
         optimizer: base optimizer such as Adam or SGD

@@ -70,7 +71,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         always require a grad scaler.
         models: list of models (i.e., the virtual pipelining models). This
             is used by the distributed optimizer for mapping parameters.
-    '''
+    """

     @classmethod
     def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
@@ -155,8 +156,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     @classmethod
     def build_model_param_gbuf_map(cls, model_gbuf_ranges):
-        '''Create a reverse of the model_gbuf_ranges, for referencing in
-        opposite direction.'''
+        """
+        Create a reverse of the model_gbuf_ranges, for referencing in
+        opposite direction.
+        """
         param_gbuf_map = {}
         for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
             for dtype, gbuf_range_map in model_gbuf_range_map.items():
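As a hedged sketch of the data shapes involved: model_gbuf_ranges appears to be indexed as [model_index][dtype]["param_map"][param] (the "param_map" key shows up in the get_model_param_range_map hunk further down), and this method inverts it into param -> (model_index, dtype). The standalone function below restates that inversion under those assumptions.

    # Hedged sketch of the reverse-map construction; the "param_map" key and
    # the overall nesting are inferred from get_model_param_range_map() below.
    def build_param_gbuf_map(model_gbuf_ranges):
        param_gbuf_map = {}
        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param in gbuf_range_map["param_map"]:
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map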
@@ -335,10 +338,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def get_model_param_range_map(self, param):
-        '''
+        """
         Given a model param, get the index sub-range of the param that this
         data-parallel rank owns.
-        '''
+        """
         model_index, dtype = self.model_param_gbuf_map[param]
         gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
         param_range_map = gbuf_range_map["param_map"][param]
@@ -346,10 +349,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def get_model_parallel_group(self):
+        """
+        With the distributed optimizer, the model parallel group is the
+        entire world.
+        """
         return None

     def state_dict(self):
+        """
+        The state dict must contain the fp32-from-float16 shards.
+        """
         state_dict = {}
         state_dict['optimizer'] = self.optimizer.state_dict()
         if self.grad_scaler:
@@ -424,10 +434,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def reduce_model_grads(self, args, timers):
-        '''Note: this is a different order of reduction, versus the non-
+        """
+        Note: this is a different order of reduction, versus the non-
         distributed optimizer, which reduces: 1) all grads, 2) embedding
         grads.
-        '''
+        """

         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
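The docstring's point is about ordering only, so here is a hedged sketch of the two call orders it contrasts. allreduce_embedding_grads is the real method from optimizer.py below; the bulk-gradient helpers are hypothetical names standing in for the rest of each method body.

    # Hedged sketch of the two reduction orders described in the docstring.
    # reduce_all_grads / reduce_scatter_grads are hypothetical helper names.
    def reduce_model_grads_non_distributed(optimizer, args, timers):
        optimizer.reduce_all_grads(args, timers)      # 1) all grads
        optimizer.allreduce_embedding_grads(args)     # 2) embedding grads

    def reduce_model_grads_distributed(optimizer, args, timers):
        optimizer.allreduce_embedding_grads(args)     # 1) embedding grads
        optimizer.reduce_scatter_grads(args, timers)  # 2) remaining (sharded) grads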
megatron/optimizer/optimizer.py
@@ -122,7 +122,7 @@ class MegatronOptimizer(ABC):

     def get_model_parallel_group(self):
-        '''
+        """
         Default returned here, but the distributed optimizer overrides this.
-        '''
+        """
         return mpu.get_model_parallel_group()
@@ -205,19 +205,21 @@ class MegatronOptimizer(ABC):

     def gather_model_params(self, args, timers):
-        '''For the case of a non-distributed-optimizer, there is nothing to
-        do here.'''
+        """
+        For the case of a non-distributed-optimizer, there is nothing to
+        do here.
+        """
         pass

     def allreduce_word_embedding_grads(self, args):
-        '''
+        """
         All-reduce word embedding grads.

         Reduce grads across first and last stages to ensure that word_embeddings
         parameters stay in sync. This should only run for models that support
         pipelined model parallelism (BERT and GPT-2).
-        '''
+        """
         if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                 mpu.get_pipeline_model_parallel_world_size() > 1:
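The docstring describes keeping the tied input/output embedding weights in sync across the first and last pipeline stages. A hedged sketch of that kind of gradient all-reduce with torch.distributed; the process group handle and the way the embedding module is reached are stand-ins, not Megatron's actual wiring:

    # Hedged sketch: sum the word-embedding grad across the ranks that hold
    # the tied weight, so both copies receive identical updates.
    import torch
    import torch.distributed as dist

    def allreduce_tied_embedding_grad(word_embeddings: torch.nn.Embedding,
                                      embedding_group: dist.ProcessGroup) -> None:
        grad = word_embeddings.weight.grad
        if grad is not None:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=embedding_group)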
@@ -240,12 +242,12 @@ class MegatronOptimizer(ABC):

     def allreduce_position_embedding_grads(self, args):
-        '''
+        """
         All-reduce position_embeddings grad across first (encoder) and
         split (decoder) stages to ensure that position embeddings parameters
         stay in sync. This should only run for T5 models with pipeline
         parallelism.
-        '''
+        """
         if mpu.is_rank_in_position_embedding_group() and \
                 mpu.get_pipeline_model_parallel_world_size() > 1 and \
                 args.pipeline_model_parallel_split_rank is not None:
@@ -259,13 +261,13 @@ class MegatronOptimizer(ABC):

     def allreduce_embedding_grads(self, args):
-        '''
+        """
         All-reduce both word and position embeddings.
-        '''
+        """
         self.allreduce_word_embedding_grads(args)
         self.allreduce_position_embedding_grads(args)

     def reduce_model_grads(self, args, timers):
-        '''
+        """
         All-reduce all grads, and all-reduce embeddings.
-        '''
+        """

         # All-reduce if needed.
         if args.DDP_impl == 'local':
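In the DDP_impl == 'local' branch above, the local DDP wrapper leaves the data-parallel gradient all-reduce to be issued explicitly, which is what this method handles. A hedged, per-parameter sketch of that step (the real path works on the wrapper's contiguous grad buffers, and the group handle here is a stand-in):

    # Hedged sketch: average grads across the data-parallel group.
    import torch.distributed as dist

    def allreduce_data_parallel_grads(params,
                                      data_parallel_group: dist.ProcessGroup) -> None:
        world_size = dist.get_world_size(group=data_parallel_group)
        for p in params:
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=data_parallel_group)
                p.grad.div_(world_size)  # average, so every rank applies the same update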