OpenDAS / Megatron-LM · Commits

Commit 82e6730c authored Mar 29, 2022 by Lawrence McAfee

    commented optimizer.py.

parent 03ba8eb1

Showing 2 changed files with 72 additions and 2 deletions:

    megatron/optimizer/distrib_optimizer.py    +31  -1
    megatron/optimizer/optimizer.py            +41  -1

megatron/optimizer/distrib_optimizer.py
@@ -29,7 +29,9 @@ from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper

class Range:
    '''A range represents the start and end points for indexing a shard
    from a full tensor.
    '''
    def __init__(self, start, end):
        self.start = start
        self.end = end
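For context, a Range is just a (start, end) pair. A minimal sketch, using a hypothetical flat buffer not taken from this commit, of how such a range can index one rank's shard out of a full tensor:

import torch

# Hypothetical flat parameter/gradient buffer of 10 elements, split into shards.
gbuf = torch.arange(10, dtype=torch.float32)
start, end = 5, 10              # what a Range(5, 10) would hold
shard = gbuf[start:end]         # this rank's shard of the full tensor
print(shard)                    # tensor([5., 6., 7., 8., 9.])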
@@ -41,6 +43,34 @@ class Range:

class DistributedOptimizer(MixedPrecisionOptimizer):
    '''Distributed optimizer, for all data types (fp16, bf16, and fp32).

    Arguments:
        optimizer: base optimizer such as Adam or SGD.
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0.
        log_num_zeros_in_grad: return the number of zeros in the gradients.
        params_have_main_grad: flag indicating whether parameters have
            a `main_grad` field. If this is set, we assume that the
            parameters' gradients are stored in the `main_grad` field
            instead of the typical `grad` field. This happens in the DDP
            cases where there is a contiguous buffer holding the gradients.
            For example, for bfloat16 we want to do gradient accumulation
            and all-reduces in float32, so we store those gradients in
            `main_grad`. Note that the main grad is not necessarily in
            float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            uses a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None, which happens when `bf16 = True` and no loss scale is
            used. Note that for `bf16 = True` we can have a constant
            gradient scaler, while for `bf16 = False` we always require
            a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    '''

    @classmethod
    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
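To illustrate the `main_grad` convention described in the docstring above, here is a minimal, hypothetical sketch (not part of this commit) of accumulating a bf16 parameter's gradient into a float32 `main_grad` buffer:

import torch

# Hypothetical bf16 parameter with an attached float32 main_grad buffer.
param = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))
param.main_grad = torch.zeros_like(param, dtype=torch.float32)

# A backward pass produces a bf16 grad; accumulate it in float32.
param.grad = torch.randn(4, dtype=torch.bfloat16)
param.main_grad += param.grad.float()
param.grad = None  # the optimizer then reads main_grad, not grad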
megatron/optimizer/optimizer.py
@@ -203,11 +203,13 @@ class MegatronOptimizer(ABC):

    def step(self, args, timers):
        pass

    def gather_model_params(self, args, timers):
        '''For the case of a non-distributed-optimizer, there is nothing to
        do here.'''
        pass

    def allreduce_word_embedding_grads(self, args):
        '''
        All-reduce word embedding grads.
@@ -236,6 +238,7 @@ class MegatronOptimizer(ABC):

            grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    def allreduce_position_embedding_grads(self, args):
        '''
        All-reduce position_embeddings grad across first (encoder) and
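The two lines of context above all-reduce the shared word-embedding gradient across the embedding group. A single-process sketch of the same torch.distributed call, using a throwaway gloo group instead of Megatron's `mpu.get_embedding_group()` (the group setup here is purely for illustration):

import os
import torch
import torch.distributed as dist

# Single-process process group, for illustration only.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

grad = torch.ones(3)
# Sums the tensor in place across all ranks in the given group.
dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=dist.group.WORLD)
print(grad)  # unchanged here, since world_size == 1

dist.destroy_process_group()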
@@ -254,11 +257,15 @@ class MegatronOptimizer(ABC):

            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())

    def allreduce_embedding_grads(self, args):
        '''All-reduce both word and position embeddings.'''
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)

    def reduce_model_grads(self, args, timers):
        '''All-reduce all grads, and all-reduce embeddings.'''

        # All-reduce if needed.
        if args.DDP_impl == 'local':
@@ -274,6 +281,34 @@ class MegatronOptimizer(ABC):

class MixedPrecisionOptimizer(MegatronOptimizer):
    """Base class for both the float-16 and the distributed optimizer.

    Arguments:
        optimizer: base optimizer such as Adam or SGD.
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0.
        log_num_zeros_in_grad: return the number of zeros in the gradients.
        params_have_main_grad: flag indicating whether parameters have
            a `main_grad` field. If this is set, we assume that the
            parameters' gradients are stored in the `main_grad` field
            instead of the typical `grad` field. This happens in the DDP
            cases where there is a contiguous buffer holding the gradients.
            For example, for bfloat16 we want to do gradient accumulation
            and all-reduces in float32, so we store those gradients in
            `main_grad`. Note that the main grad is not necessarily in
            float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            uses a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None, which happens when `bf16 = True` and no loss scale is
            used. Note that for `bf16 = True` we can have a constant
            gradient scaler, while for `bf16 = False` we always require
            a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
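The grad_scaler documented above may be a constant scaler when `bf16 = True`. As a rough, hypothetical illustration of constant loss scaling (this is not Megatron's grad-scaler API): the loss is multiplied by a fixed scale before backward, and the gradients are divided by it before clipping and the optimizer step.

import torch

scale = 1024.0                      # hypothetical constant loss scale
model = torch.nn.Linear(4, 1)
loss = model(torch.randn(2, 4)).sum()

(loss * scale).backward()           # scaled loss -> scaled grads
for p in model.parameters():
    p.grad.div_(scale)              # unscale before clipping / optimizer step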
@@ -378,7 +413,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):

            grad_norm = self.clip_grad_norm(self.clip_grad)
            timers('optimizer-clip-main-grad').stop()

-        # count the zeros in the grads
+        # Count the zeros in the grads.
        timers('optimizer-count-zeros').start()
        num_zeros_in_grad = self.count_zeros() if \
            self.log_num_zeros_in_grad else None
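The `count_zeros()` call above, made only when `log_num_zeros_in_grad` is set, reports how many gradient entries are exactly zero. A simple, hypothetical way to compute that number over a list of parameters (not the method used in this file):

import torch

params = [torch.nn.Parameter(torch.randn(3, 3)) for _ in range(2)]
for p in params:
    p.grad = torch.randn_like(p)
params[0].grad[0].zero_()           # force some zeros for the example

# Count gradient elements that are exactly zero across all parameters.
num_zeros_in_grad = sum(int((p.grad == 0).sum()) for p in params)
print(num_zeros_in_grad)            # 3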
@@ -415,12 +450,17 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):

            to do gradient accumulation and all-reduces in float32,
            and as a result we store those gradients in `main_grad`.
            Note that the main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            uses a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None, which happens when `bf16 = True` and no loss scale is
            used. Note that for `bf16 = True` we can have a constant
            gradient scaler, while for `bf16 = False` we always require
            a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,