OpenDAS / Megatron-LM · Commits

Commit fb218c9d, authored Dec 24, 2020 by mohammad
Parent: 2eaa3ccc

    megatron optimizer tested, before working on clip grad

Showing 2 changed files with 21 additions and 51 deletions:

    megatron/optimizer/optimizer.py    +20  -3
    megatron/training.py                +1   -48
megatron/optimizer/optimizer.py  (view file @ fb218c9d)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron optimizer."""
from abc import ABC
from abc import abstractmethod
...

@@ -22,7 +37,7 @@ def get_megatron_optimizer(optimizer, model):
     if args.loss_scale:
         grad_scaler = ConstantGradScaler(args.loss_scale)
     # Dynamic loss scale.
     else:
         grad_scaler = DynamicGradScaler(
             initial_scale=args.initial_loss_scale,
             min_scale=args.min_loss_scale,
...

@@ -70,7 +85,9 @@ class MegatronGradScaler(ABC):
 class ConstantGradScaler(MegatronGradScaler):
-    pass
+    def update(self, found_inf):
+        pass


 class DynamicGradScaler(MegatronGradScaler):

...
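The change in this hunk swaps the bare pass body for a concrete update method, which is what an abstract base class requires before the subclass can be instantiated (note the abstractmethod import at the top of the file). A self-contained sketch of that pattern; the constructor and scale attribute are assumptions for illustration, not the real class:

from abc import ABC, abstractmethod

class MegatronGradScaler(ABC):
    """Sketch of the abstract base; the real class carries more state."""

    def __init__(self, initial_scale):
        self.scale = float(initial_scale)      # assumed attribute

    @abstractmethod
    def update(self, found_inf):
        """Adjust the scale after a step; found_inf flags inf/nan grads."""

class ConstantGradScaler(MegatronGradScaler):
    # Without a concrete update(), instantiating this subclass would raise
    # TypeError because the inherited method is still abstract.
    def update(self, found_inf):
        pass                                   # a constant scale never changes

scaler = ConstantGradScaler(2 ** 16)
scaler.update(found_inf=False)                 # no-op by design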
megatron/training.py  (view file @ fb218c9d)
...

@@ -39,7 +39,6 @@ from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.fp16 import FP16_Module
-#from megatron.fp16 import FP16_Optimizer
 from megatron.optimizer.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
...

@@ -234,15 +233,6 @@ def get_optimizer(model):
     # Wrap into fp16 optimizer.
     optimizer = get_megatron_optimizer(optimizer, model)
-    '''
-    optimizer = FP16_Optimizer(optimizer,
-                               static_loss_scale=args.loss_scale,
-                               dynamic_loss_scale=args.dynamic_loss_scale,
-                               dynamic_loss_args={
-                                   'scale_window': args.loss_scale_window,
-                                   'min_scale': args.min_scale,
-                                   'delayed_shift': args.hysteresis})
-    '''
     return optimizer
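The commented-out FP16_Optimizer construction disappears here; its responsibilities (loss scaling, skipping updates on overflow) now sit behind get_megatron_optimizer. Judging only from the calls visible elsewhere in this commit (zero_grad(), scale_loss(), and step() returning a success flag), the returned object exposes roughly the interface sketched below. This is an illustrative stand-in, not Megatron's implementation:

import torch

class MegatronOptimizerSketch:
    """Hypothetical wrapper around a torch optimizer plus a grad scaler."""

    def __init__(self, optimizer, grad_scaler):
        self.optimizer = optimizer            # e.g. torch.optim.Adam
        self.grad_scaler = grad_scaler        # object with .scale and .update()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def scale_loss(self, loss):
        # Called from backward_step: keep small fp16 gradients representable.
        return loss * self.grad_scaler.scale

    def step(self):
        # Called from train_step: apply the update only if all grads are finite.
        params = [p for group in self.optimizer.param_groups
                  for p in group['params'] if p.grad is not None]
        found_inf = any(not torch.isfinite(p.grad).all() for p in params)
        self.grad_scaler.update(found_inf)
        if found_inf:
            return False                      # skip the update; caller sees failure
        for p in params:
            p.grad.div_(self.grad_scaler.scale)   # undo the loss scaling
        self.optimizer.step()
        return True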
...

@@ -373,13 +363,7 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
     if output_tensor_grad is None:
         output_tensor = optimizer.scale_loss(output_tensor)
     torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
-    '''
-    if args.fp16 and output_tensor_grad is None:
-        optimizer.backward(output_tensor, update_master_grads=False,
-                           output_tensor_grad=output_tensor_grad)
-    else:
-        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
-    '''
     # Collect the grad of the input_tensor.
     input_tensor_grad = None
     if input_tensor is not None:
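In the new path the loss is scaled only when output_tensor_grad is None, presumably the pipeline stage that actually holds the scalar loss; other stages backpropagate the gradient handed to them unchanged, replacing the old optimizer.backward() route. A condensed sketch of that control flow (the optimizer argument is anything providing scale_loss, as in the sketch above):

import torch

def backward_step_sketch(optimizer, output_tensor, output_tensor_grad):
    # No incoming gradient means output_tensor is the loss itself: scale it
    # before backward so fp16 gradients do not underflow.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    # Either backprop the (scaled) scalar loss, or push the gradient received
    # from the next pipeline stage through this stage's part of the graph.
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)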
...

@@ -598,12 +582,6 @@ def train_step(forward_step_func, data_iterator,
     # Set grad to zero.
     optimizer.zero_grad()
-    '''
-    if args.fp16:
-        optimizer.zero_grad(set_grads_to_None=True)
-    else:
-        optimizer.zero_grad()
-    '''
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         losses_reduced = forward_backward_pipelining(
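The removed branch existed only to pass set_grads_to_None=True to the old FP16 optimizer; with the new wrapper a plain optimizer.zero_grad() suffices and any such policy lives inside the wrapper. For comparison, stock PyTorch optimizers expose the same knob directly as set_to_none:

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(1, 4)).sum().backward()
# Drops the gradient tensors instead of filling them with zeros,
# which lowers memory traffic compared with zeroing in place.
opt.zero_grad(set_to_none=True)
print(all(p.grad is None for p in model.parameters()))   # True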
...

@@ -636,31 +614,6 @@ def train_step(forward_step_func, data_iterator,
                                      group=mpu.get_embedding_group())
     timers('backward-embedding-all-reduce').stop()
-    # Update master gradients.
-    '''
-    timers('backward-master-grad').start()
-    if args.fp16:
-        optimizer.update_master_grads()
-    timers('backward-master-grad').stop()
-    '''
-    # Clipping gradients helps prevent the exploding gradient.
-    '''
-    timers('backward-clip-grad').start()
-    if args.clip_grad > 0.:
-        if not args.fp16:
-            named_parameters = model.named_parameters()
-            parameters = []
-            parameter_names = []
-            for parameter_name, parameter in model.named_parameters():
-                parameters.append(parameter)
-                parameter_names.append(parameter_name)
-            mpu.clip_grad_norm(parameters, args.clip_grad,
-                               parameter_names=parameter_names)
-        else:
-            optimizer.clip_master_grads(args.clip_grad)
-    timers('backward-clip-grad').stop()
-    '''
     # Update parameters.
     timers('optimizer').start()
     update_successfull = optimizer.step()

...
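The two blocks removed above are the old master-gradient update and the old gradient-clipping path (collect the model's parameters and call mpu.clip_grad_norm, or clip_master_grads under fp16); per the commit message, clipping is the next thing to be reworked. For reference, the generic form of gradient clipping with stock PyTorch, rather than the model-parallel-aware mpu.clip_grad_norm, looks like this:

import torch

model = torch.nn.Linear(4, 4)                 # stand-in model
model(torch.randn(2, 4)).sum().backward()

clip_grad = 1.0                               # plays the role of args.clip_grad
if clip_grad > 0.0:
    # Rescales all gradients in place so their global L2 norm is <= clip_grad,
    # and returns the norm measured before clipping.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
    print(float(total_norm))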