OpenDAS / Megatron-LM · Commits

Commit 97ba5c0e
Authored Dec 26, 2020 by mohammad
Parent: 0888a3e1

    load and save state dicts added

Showing 3 changed files with 45 additions and 10 deletions (+45 -10):

    megatron/optimizer/__init__.py     +1   -3
    megatron/optimizer/grad_scaler.py  +22  -2
    megatron/optimizer/optimizer.py    +22  -5
megatron/optimizer/__init__.py  (view file @ 97ba5c0e)

@@ -25,7 +25,6 @@ def _get_params_for_weight_decay_optimization(module):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and baises will have no weight decay but the rest will.
     """
     args = get_args()
     LayerNorm = import_layernorm(args.fp32_residual_connection)
@@ -48,7 +47,6 @@ def _get_params_for_weight_decay_optimization(module):
 def get_megatron_optimizer(model):
     args = get_args()
     # Base optimizer.
@@ -77,4 +75,4 @@ def get_megatron_optimizer(model):
                                            args.clip_grad)
     # FP32.
-    return FP32Optimizer(optimizer, model, args.clip_grad)
+    return FP32Optimizer(optimizer, args.clip_grad)
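The only visible behavioral change here is the FP32 return path: get_megatron_optimizer now builds the wrapper from the base optimizer and args.clip_grad alone, with no model argument, matching the updated FP32Optimizer.__init__ in megatron/optimizer/optimizer.py further down. Below is a minimal, self-contained stand-in (not the real Megatron class) that mirrors the new constructor shape, assuming a plain torch.optim.SGD as the wrapped base optimizer:

    # Sketch only: a toy stand-in mirroring the new (optimizer, clip_grad)
    # constructor shape from this commit; it is not the real FP32Optimizer.
    import torch


    class ToyFP32Wrapper:
        def __init__(self, optimizer, clip_grad):
            self.optimizer = optimizer   # the wrapped torch optimizer
            self.clip_grad = clip_grad   # gradient-clipping threshold


    params = [torch.nn.Parameter(torch.zeros(4))]
    wrapper = ToyFP32Wrapper(torch.optim.SGD(params, lr=0.1), clip_grad=1.0)
    print(wrapper.clip_grad)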
megatron/optimizer/grad_scaler.py  (view file @ 97ba5c0e)

@@ -40,7 +40,6 @@ class MegatronGradScaler(ABC):
     def update(self, found_inf):
         pass
-    '''
     @abstractmethod
     def state_dict(self):
         pass
@@ -48,7 +47,7 @@ class MegatronGradScaler(ABC):
     @abstractmethod
     def load_state_dict(self, state_dict):
         pass
-    '''
 class ConstantGradScaler(MegatronGradScaler):
@@ -56,6 +55,13 @@ class ConstantGradScaler(MegatronGradScaler):
     def update(self, found_inf):
         pass
+    def state_dict(self):
+        return dict()
+
+    def load_state_dict(self, state_dict):
+        pass
 class DynamicGradScaler(MegatronGradScaler):
@@ -111,3 +117,17 @@ class DynamicGradScaler(MegatronGradScaler):
                 self._hysteresis_tracker = self.hysteresis
                 # and scale up the loss scale.
                 self._scale = self._scale * self.growth_factor
+    def state_dict(self):
+        state_dict = {}
+        state_dict['scale'] = self._scale
+        state_dict['growth_tracker'] = self._growth_tracker
+        state_dict['hysteresis_tracker'] = self._hysteresis_tracker
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
+        self._growth_tracker = state_dict['growth_tracker']
+        self._hysteresis_tracker = state_dict['hysteresis_tracker']
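These additions give every grad scaler a save/restore protocol: ConstantGradScaler has nothing to persist, while DynamicGradScaler saves its loss scale plus the growth and hysteresis trackers, and on load moves the saved scale tensor back onto the current CUDA device. A rough sketch of that round trip follows, built with a hand-made dict rather than a real scaler instance; the initial scale of 65536.0 is an illustrative value:

    # Sketch only: the shape of the state handled by the new
    # DynamicGradScaler.state_dict()/load_state_dict() above. The dict is
    # built by hand here; it does not come from a real scaler instance.
    import torch

    on_gpu = torch.cuda.is_available()
    scale = torch.FloatTensor([65536.0])        # illustrative loss scale
    saved = {
        'scale': scale.cuda() if on_gpu else scale,
        'growth_tracker': 0,
        'hysteresis_tracker': 2,
    }

    # load_state_dict() moves the saved scale onto the current CUDA device:
    if on_gpu:
        restored_scale = saved['scale'].cuda(torch.cuda.current_device())
    else:
        restored_scale = saved['scale']         # CPU fallback for this sketch
    print(restored_scale.item(), saved['growth_tracker'],
          saved['hysteresis_tracker'])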
megatron/optimizer/optimizer.py  (view file @ 97ba5c0e)

@@ -145,7 +145,6 @@ class MegatronOptimizer(ABC):
     def step(self):
         pass
-    '''
     @abstractmethod
     def state_dict(self):
         pass
@@ -153,7 +152,6 @@ class MegatronOptimizer(ABC):
     @abstractmethod
     def load_state_dict(self, state_dict):
         pass
-    '''
     # Promote state so it can be retrieved or set via
     # "optimizer_instance.state"
@@ -180,7 +178,6 @@ class MegatronOptimizer(ABC):
 class FP16OptimizerWithFP16Params(MegatronOptimizer):
     def __init__(self, optimizer, grad_scaler, clip_grad):
         super(FP16OptimizerWithFP16Params, self).__init__(optimizer)
@@ -369,12 +366,32 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         return True
+    def state_dict(self):
+        state_dict = {}
+        state_dict['optimizer'] = self.optimizer.state_dict()
+        state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        # Defer to the class to load.
+        self.optimizer.load_state_dict(state_dict['optimizer'])
+        self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+        # Copy data for the master params.
+        for current_group, saved_group in zip(
+                self.fp32_from_fp16_groups,
+                state_dict['fp32_from_fp16_params']):
+            for current_param, saved_param in zip(current_group, saved_group):
+                current_param.data.copy_(saved_param.data)
 class FP32Optimizer(MegatronOptimizer):
-    def __init__(self, optimizer, model, clip_grad):
+    def __init__(self, optimizer, clip_grad):
         super(FP32Optimizer, self).__init__(optimizer)
-        self.model = model
         self.clip_grad = clip_grad
         self._scale = torch.cuda.FloatTensor([1.0])
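Taken together, the optimizer wrappers now expose the same state_dict()/load_state_dict() protocol as a plain torch optimizer, so optimizer state (including the grad scaler state and the FP32 master copies of the FP16 params) can be written into and restored from a training checkpoint. Below is a minimal round-trip sketch using a plain torch.optim.Adam as a stand-in for the Megatron wrapper; the checkpoint filename is an illustrative assumption, not something from this commit:

    # Sketch only: the checkpoint round trip that the new state_dict()/
    # load_state_dict() methods enable. A plain torch.optim.Adam stands in
    # for the Megatron optimizer wrapper; 'optim_ckpt.pt' is an illustrative
    # path chosen for this example.
    import torch

    params = [torch.nn.Parameter(torch.randn(8))]
    optimizer = torch.optim.Adam(params, lr=1e-3)

    # Take one step so the optimizer has non-trivial state to save.
    params[0].grad = torch.randn(8)
    optimizer.step()

    torch.save({'optimizer': optimizer.state_dict()}, 'optim_ckpt.pt')

    # Later / on restart: rebuild the optimizer and restore its state.
    restored = torch.optim.Adam(params, lr=1e-3)
    restored.load_state_dict(torch.load('optim_ckpt.pt')['optimizer'])
    print(len(restored.state_dict()['state']))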