OpenDAS / Megatron-LM / Commits

Commit 5706ba42, authored Feb 11, 2022 by Lawrence McAfee

    bit more progress

Parent: f48e1f29
Showing 1 changed file with 118 additions and 82 deletions.

megatron/optimizer/optimizer.py (+118, -82), view file @ 5706ba42
...
...
@@ -121,11 +121,21 @@ class MegatronOptimizer(ABC):
         return self.get_loss_scale() * loss


+    @abstractmethod
+    def reduce_gradients(self):
+        pass
+
     @abstractmethod
     def step(self):
         pass

+    @abstractmethod
+    def gather_params(self):
+        pass
+
     @abstractmethod
     def reload_model_params(self):
         """Refreshes any internal state from the current model parameters.
...
...
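The hunk above widens the MegatronOptimizer abstract interface: concrete optimizers must now implement reduce_gradients() and gather_params() alongside step(). A minimal, self-contained sketch of that contract; the MyOptimizer subclass and its print-based bodies are illustrative placeholders, not part of the commit:

    from abc import ABC, abstractmethod

    class MegatronOptimizer(ABC):
        # Trimmed to the three hooks touched by this hunk.
        @abstractmethod
        def reduce_gradients(self):
            pass

        @abstractmethod
        def step(self):
            pass

        @abstractmethod
        def gather_params(self):
            pass

    class MyOptimizer(MegatronOptimizer):
        # Hypothetical subclass: every abstract hook must be overridden,
        # otherwise instantiation raises TypeError.
        def reduce_gradients(self):
            print("reduce (all-reduce / reduce-scatter) gradients here")

        def step(self):
            print("update parameters here")

        def gather_params(self):
            print("all-gather updated parameter shards here")

    MyOptimizer()  # ok; omit any override and this line raises TypeError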
@@ -170,36 +180,13 @@ class MegatronOptimizer(ABC):

-class Float16OptimizerWithFloat16Params(MegatronOptimizer):
-    """Float16 optimizer for fp16 and bf16 data types.
-
-    Arguments:
-        optimizer: base optimizer such as Adam or SGD
-        clip_grad: clip gradients with this global L2 norm. Note
-            that clipping is ignored if clip_grad == 0
-        log_num_zeros_in_grad: return number of zeros in the gradients.
-        params_have_main_grad: flag indicating if parameters have
-            a `main_grad` field. If this is set, we are assuming
-            that the model parameters are stored in the `main_grad`
-            field instead of the typical `grad` field. This happens
-            for the DDP cases where there is a continuous buffer
-            holding the gradients. For example for bfloat16, we want
-            to do gradient accumulation and all-reduces in float32
-            and as a result we store those gradients in the main_grad.
-            Note that main grad is not necessarily in float32.
-        bf16: if true, the model is running in bfloat16.
-        grad_scaler: used for scaling gradients. Note that this can be
-            None. This case happens when `bf16 = True` and we don't
-            use any loss scale. Note that for `bf16 = True`, we can have
-            a constant gradient scaler. Also for `bf16 = False`, we
-            always require a grad scaler.
-    """
+class BaseFloat16Optimizer(MegatronOptimizer):

     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                  bf16, grad_scaler):

-        super(Float16OptimizerWithFloat16Params, self).__init__(
+        super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
             params_have_main_grad, use_contiguous_buffers_in_local_ddp)
...
...
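The super() change above swaps the Python 2-style two-argument call for the zero-argument form, which resolves the parent through the MRO of the instance rather than a hard-coded class name; that matters here because the constructor now lives in BaseFloat16Optimizer and is inherited under different class names. A small stand-alone illustration (Base, Mid, and Leaf are hypothetical classes, not from the commit):

    class Base:
        def __init__(self, x):
            self.x = x

    class Mid(Base):
        def __init__(self, x):
            # Zero-argument super() resolves via the MRO of type(self),
            # so it keeps working if this method is inherited or the class
            # is renamed; the old two-argument form hard-codes the class name.
            super().__init__(x)

    class Leaf(Mid):
        pass

    print(Leaf(3).x)  # 3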
@@ -228,6 +215,48 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
         if self.grad_scaler is None:
             self._scale_one = torch.cuda.FloatTensor([1.0])

+    def get_loss_scale(self):
+        if self.grad_scaler is None:
+            return self._scale_one
+        return self.grad_scaler.scale
+
+
+# class Float16OptimizerWithFloat16Params(MegatronOptimizer):
+class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
+    """Float16 optimizer for fp16 and bf16 data types.
+
+    Arguments:
+        optimizer: base optimizer such as Adam or SGD
+        clip_grad: clip gradients with this global L2 norm. Note
+            that clipping is ignored if clip_grad == 0
+        log_num_zeros_in_grad: return number of zeros in the gradients.
+        params_have_main_grad: flag indicating if parameters have
+            a `main_grad` field. If this is set, we are assuming
+            that the model parameters are stored in the `main_grad`
+            field instead of the typical `grad` field. This happens
+            for the DDP cases where there is a continuous buffer
+            holding the gradients. For example for bfloat16, we want
+            to do gradient accumulation and all-reduces in float32
+            and as a result we store those gradients in the main_grad.
+            Note that main grad is not necessarily in float32.
+        bf16: if true, the model is running in bfloat16.
+        grad_scaler: used for scaling gradients. Note that this can be
+            None. This case happens when `bf16 = True` and we don't
+            use any loss scale. Note that for `bf16 = True`, we can have
+            a constant gradient scaler. Also for `bf16 = False`, we
+            always require a grad scaler.
+    """
+
+    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
+                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
+                 bf16, grad_scaler):
+
+        super().__init__(
+            optimizer, clip_grad, log_num_zeros_in_grad,
+            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
+            bf16, grad_scaler)

         # ======================
         # main parameter stuff
         # ======================
...
...
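With get_loss_scale() hoisted into BaseFloat16Optimizer, both the ordinary fp16 optimizer and the new distributed one share the same rule: a constant scale of 1.0 when grad_scaler is None (the bf16-without-loss-scale case), otherwise the scaler's current scale. A runnable sketch of just that logic; _ToyBaseFloat16Optimizer is a stand-in that uses a CPU tensor instead of torch.cuda.FloatTensor so it runs anywhere:

    import torch

    class _ToyBaseFloat16Optimizer:
        # Illustrative stand-in for BaseFloat16Optimizer's loss-scale handling.
        def __init__(self, grad_scaler=None):
            self.grad_scaler = grad_scaler
            if self.grad_scaler is None:
                # bf16 without a scaler: scale losses by a constant 1.0.
                self._scale_one = torch.tensor([1.0])

        def get_loss_scale(self):
            if self.grad_scaler is None:
                return self._scale_one
            return self.grad_scaler.scale

        def scale_loss(self, loss):
            # Mirrors MegatronOptimizer.scale_loss: simple scaling.
            return self.get_loss_scale() * loss

    opt = _ToyBaseFloat16Optimizer()
    print(opt.scale_loss(torch.tensor(2.5)))  # tensor([2.5000])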
@@ -319,29 +348,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-        # >>>
-        # from lutil import pax
-        # pax(0, {
-        #     # "float16_groups / len" : [ len(g) for g in self.float16_groups ],
-        #     # "fp32_from_float16_groups / len" :
-        #     #     [ len(g) for g in self.fp32_from_float16_groups ],
-        #     # "float16_groups / 0" : self.float16_groups[0],
-        #     # "float16_groups / 1" : self.float16_groups[1],
-        #     # "fp32_from_float16_groups / 0" : self.fp32_from_float16_groups[0],
-        #     # "fp32_from_float16_groups / 1" : self.fp32_from_float16_groups[1],
-        #     # "fp32_from_float32_groups" : self.fp32_from_fp32_groups,
-        #     "optimizer" : self.optimizer,
-        #     # "optimizer / sd" : self.optimizer.state_dict(),
-        #     # "optimizer / state" : self.optimizer.state_dict()["state"],
-        #     # "optimizer / pg" : self.optimizer.state_dict()["param_groups"],
-        #     # "optimizer / pg / 0" : self.optimizer.state_dict()["param_groups"][0],
-        #     # "optimizer / pg / 1" : self.optimizer.state_dict()["param_groups"][1],
-        #     "optimizer -> pg" : optimizer.param_groups,
-        #     "optimizer -> pg / 0" : optimizer.param_groups[0]["params"],
-        #     "optimizer -> pg / 1" : optimizer.param_groups[1]["params"],
-        # })
-        # <<<
-
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
...
...
@@ -357,12 +363,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
             _zero_grad_group_helper(group, set_to_none)

-    def get_loss_scale(self):
-        if self.grad_scaler is None:
-            return self._scale_one
-        return self.grad_scaler.scale
-
     # >>>
     def reduce_gradients(self, model):
...
...
@@ -658,7 +658,8 @@ from lutil import pax, tp
 # <<<

 # class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
-class Float16DistributedOptimizer(MegatronOptimizer):
+# class Float16DistributedOptimizer(MegatronOptimizer):
+class Float16DistributedOptimizer(BaseFloat16Optimizer):

     # >>>
     @classmethod
...
...
@@ -702,7 +703,8 @@ class Float16DistributedOptimizer(MegatronOptimizer):
         super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad, use_contiguous_buffers_in_local_ddp)
+            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
+            bf16, grad_scaler)

         # >>>
         # self.test_reduce_scatter()
...
...
@@ -759,34 +761,41 @@ class Float16DistributedOptimizer(MegatronOptimizer):
-        allocate_shard = lambda shard_size, dtype : torch.empty(
-            (shard_size,),
-            dtype = dtype,
-            device = torch.cuda.current_device())
+        allocate_shard = lambda shard_size, dtype : torch.empty(
+            (shard_size,),
+            dtype = dtype,
+            device = torch.cuda.current_device(),
+            requires_grad = True)
+        # return torch.nn.Parameter ?
         # allocate_shard = lambda dtype : MemoryBuffer(self.shard_size, dtype)

-        # Collect DP world shard infos, per group.
+        # Allocate shards.
+        # (Also, collect world DP shard info.)
         model_main_dtypes = set([args.params_dtype, torch.float])
         self.world_shard_info_groups = [] # world_group_shard_infos ?
         self.main_param_shard_groups = []
-        for model_param_group_size in model_param_group_sizes:
+        for group_index, model_param_group in enumerate(self.model_param_groups):

-            max_world_shard_size = int(math.ceil(model_param_group_size /
-                                                 self.data_parallel_world_size))
+            model_param_size = model_param_group["size"]
+            max_world_shard_size = int(math.ceil(model_param_size /
+                                                 self.data_parallel_world_size))

-            # Group shard infos.
-            shard_infos = []
+            # DP world shard infos.
+            world_shard_infos = []
             for r in range(self.data_parallel_world_size):
-                shard_start_index = r * max_shard_size
-                shard_end_index = min(self.total_param_size,
-                                      shard_start_index + max_shard_size)
-                shard_infos.append({
+                shard_start_index = r * max_world_shard_size
+                shard_end_index = min(model_param_size,
+                                      shard_start_index + max_world_shard_size)
+                world_shard_infos.append({
                     "start" : shard_start_index,
                     "end" : shard_end_index,
                     "size" : shard_end_index - shard_start_index,
                 })
-            self.world_shard_info_groups.append(shard_infos)
+            self.world_shard_info_groups.append(world_shard_infos)
+            # pax(0, {"world_shard_infos": world_shard_infos})

             # Allocate shards.
-            local_shard_size = \
-                self.world_shard_infos[self.data_parallel_rank]["size"]
+            # (Non-fp32 shards are for convenience; e.g., intermediaries
+            # between model params and main fp32 shard. Necessary???)
+            local_shard_size = world_shard_infos[self.data_parallel_rank]["size"]
             # # self.main_param_shard = allocate_shard(torch.float)
             # # self.main_grad_shard = allocate_shard(torch.float)
...
...
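The loop above carves each flattened parameter group into one shard per data-parallel rank: a ceil-divided maximum shard size, with the final rank's shard clipped by min(). A stand-alone sketch of the same arithmetic; build_world_shard_infos is a hypothetical helper, not something the commit defines:

    import math

    def build_world_shard_infos(param_group_size, data_parallel_world_size):
        # Split a flattened parameter group of `param_group_size` elements
        # into one contiguous shard per data-parallel rank.
        max_world_shard_size = int(math.ceil(param_group_size /
                                             data_parallel_world_size))
        infos = []
        for r in range(data_parallel_world_size):
            start = r * max_world_shard_size
            end = min(param_group_size, start + max_world_shard_size)
            infos.append({"start": start, "end": end, "size": end - start})
        return infos

    # Example: 10 parameters over 4 ranks -> shard sizes 3, 3, 3, 1.
    print(build_world_shard_infos(10, 4))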
@@ -795,29 +804,50 @@ class Float16DistributedOptimizer(MegatronOptimizer):
             # self.adam_m_shard = allocate_shard(torch.float)
             # self.adam_v_shard = allocate_shard(torch.float)

-            self.main_param_shard_groups.append({
-                ty : allocate_shard(ty)
-                for ty in model_main_dtypes})
+            main_param_shards = {
+                ty : allocate_shard(local_shard_size, ty)
+                for ty in model_main_dtypes}
+            self.main_param_shard_groups.append(main_param_shards)
+
+            # Update optimizer group.
+            self.optimizer.param_groups[group_index]["params"] = \
+                [ main_param_shards[torch.float] ]

-        # >>>
-        # pax(0, {
-        #     "total_param_size" : self.total_param_size,
-        #     "max_shard_size" : max_shard_size,
-        #     "shard_infos" : self.shard_infos,
-        #     "shard_size" : shard_size,
-        #     "param_shard_map" : self.param_shard_map,
-        #     "param_groups" : self.optimizer.param_groups,
-        #     "params" : self.optimizer.param_groups[group_index]["params"],
-        # })
-        # <<<
-
-    def get_loss_scale(self):
-        raise Exception("hi.")
+        # Leverage state_dict() and load_state_dict() to
+        # recast preexisting per-param state tensors
+        self.optimizer.load_state_dict(self.optimizer.state_dict())

+    # def get_loss_scale(self):
+    #     if self.grad_scaler is None:
+    #         return self._scale_one
+    #     return self.grad_scaler.scale
     def load_state_dict(self):
         raise Exception("hi.")
     def reload_model_params(self):
         raise Exception("hi.")
     def state_dict(self):
         raise Exception("hi.")
-    def zero_grad(self):
-        raise Exception("hi.")
+    def zero_grad(self, set_to_none=True):
+
+        params = []
+        for model_param_group in self.model_param_groups:
+            params.extend(model_param_group["offset_map"].keys())
+        for main_group in self.optimizer.param_groups:
+            params.extend(main_group["params"])
+        _zero_grad_group_helper(params, set_to_none)
+
+        # pax(0, {
+        #     "model_param_groups" : self.model_param_groups,
+        #     "params" : params,
+        # })

     def reduce_gradients(self, model):
...
...
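The hunk above allocates one main-parameter shard per dtype and then points the wrapped optimizer's param group at the fp32 shard only, so the inner optimizer steps on the shard rather than the full-precision-less model parameters. A toy illustration of that repointing; the shard size, the SGD inner optimizer, and the variable names here are placeholders, not the commit's code:

    import torch

    model_main_dtypes = {torch.half, torch.float}
    local_shard_size = 8

    # A stand-in fp16 model parameter and a plain inner optimizer.
    model_param = torch.nn.Parameter(torch.zeros(local_shard_size, dtype=torch.half))
    inner_opt = torch.optim.SGD([model_param], lr=0.1)

    # One shard buffer per dtype, as in the diff's dict comprehension.
    main_param_shards = {
        ty: torch.empty((local_shard_size,), dtype=ty, requires_grad=True)
        for ty in model_main_dtypes}

    # Repoint group 0 at the fp32 main shard only; the fp16 model param
    # is no longer what the inner optimizer updates.
    inner_opt.param_groups[0]["params"] = [main_param_shards[torch.float]]

    print([p.dtype for p in inner_opt.param_groups[0]["params"]])  # [torch.float32]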
@@ -880,10 +910,16 @@ class Float16DistributedOptimizer(MegatronOptimizer):
         #     })
         # <<<

+    def step(self):
+        raise Exception("step.")
+
+    def gather_params(self):
+        raise Exception("gather params.")
+
     # <<<
...
...
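step() and gather_params() are left as placeholders that simply raise. If such stubs were kept longer term, NotImplementedError would be the more idiomatic signal; a sketch (the class below is illustrative, not the commit's code):

    class Float16DistributedOptimizerStub:
        # Work-in-progress stubs: raise NotImplementedError so callers and
        # linters recognize these as intentionally unimplemented hooks.
        def step(self):
            raise NotImplementedError("distributed optimizer step() not implemented yet")

        def gather_params(self):
            raise NotImplementedError("distributed optimizer gather_params() not implemented yet")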