OpenDAS / Megatron-LM · Commits

Commit 6875dff5, authored Feb 14, 2022 by Lawrence McAfee
Parent: 1215c420

    fix zero_grad; set_to_none = False
Showing 1 changed file with 61 additions and 39 deletions.

megatron/optimizer/optimizer.py (+61, -39)
@@ -768,9 +768,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # Allocate shards.
        # (Also, collect world DP shard info.)
        model_main_dtypes = set([ args.params_dtype, torch.float ])
        # model_main_dtypes = set([ args.params_dtype, torch.float ])
        model_main_dtypes = set([ torch.float ]) # fp32 only, for now
        self.world_shard_info_groups = [] # world_group_shard_infos ?
        self.main_param_shard_groups = []
        # self.main_param_shard_groups = []
        for group_index, model_param_group in enumerate(self.model_param_groups):
            # pax(0, {
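Context for this hunk: the change narrows the "main" dtypes to fp32 only. A mixed-precision distributed optimizer keeps an fp32 master copy of the fp16/bf16 model parameters and steps the optimizer in fp32; below is a minimal standalone sketch of that pattern (not Megatron's actual code, just the general idea).

```python
import torch

# fp16 model parameter and its fp32 "main" (master) copy.
model_param = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
main_param = model_param.detach().clone().float()
main_param.requires_grad = True

# Pretend a backward pass produced an fp32 gradient for the main copy.
main_param.grad = torch.randn_like(main_param)

# The optimizer steps in fp32 ...
opt = torch.optim.SGD([main_param], lr=0.1)
opt.step()

# ... and the updated fp32 values are copied back into the fp16 model param.
with torch.no_grad():
    model_param.copy_(main_param)
```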
@@ -820,26 +821,27 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                local_shard_end_index, orig_end_index - local_shard_start_index)
            if param_shard_end_index > param_shard_start_index:
                # Indexes are relative to local shard start index.
                # local_shard_info["param_index_map"][param] = {
                #     "param" : (
                #         param_shard_start_index,
                #         param_shard_end_index,
                #     ),
                #     "shard" : (
                #         param_shard_start_index - local_shard_start_index,
                #         param_shard_end_index - local_shard_start_index,
                #     ),
                # }
                # local_shard_info["param_slice_index_map"][param] = {
                #     "param_start" :
                #         param_shard_start_index,
                #     "shard_start" :
                #         param_shard_start_index - local_shard_start_index,
                #     "size":
                #         param_shard_end_index - param_shard_start_index,
                # }
            # if param_shard_end_index > param_shard_start_index:
            #     # Indexes are relative to local shard start index.
            #     # local_shard_info["param_index_map"][param] = {
            #     #     "param" : (
            #     #         param_shard_start_index,
            #     #         param_shard_end_index,
            #     #     ),
            #     #     "shard" : (
            #     #         param_shard_start_index - local_shard_start_index,
            #     #         param_shard_end_index - local_shard_start_index,
            #     #     ),
            #     # }
            # local_shard_info["param_slice_index_map"][param] = {
            #     "param_start" :
            #         param_shard_start_index,
            #     "shard_start" :
            #         param_shard_start_index - local_shard_start_index,
            #     "size":
            #         param_shard_end_index - param_shard_start_index,
            # }
            if shard_end_index > shard_start_index:
                local_shard_info["param_slice_index_map"][param] = {
                    "orig_start" : orig_start_index,
                    "shard_start" : shard_start_index,
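The new "param_slice_index_map" entries record where a parameter's slice lives both inside the parameter itself ("orig_start") and inside this rank's local shard ("shard_start"), plus the slice length ("size"). A self-contained sketch of how such an entry could be computed; function and variable names here are illustrative, not taken from the repo.

```python
# Hedged sketch: compute the overlap of one parameter's flattened range
# [param_world_start, param_world_end) with this rank's shard range
# [shard_world_start, shard_world_end).
def slice_index_entry(param_world_start, param_world_end,
                      shard_world_start, shard_world_end):
    overlap_start = max(param_world_start, shard_world_start)
    overlap_end = min(param_world_end, shard_world_end)
    if overlap_end <= overlap_start:
        return None  # this rank owns no elements of the parameter
    return {
        # offset of the overlap inside the parameter itself
        "orig_start": overlap_start - param_world_start,
        # offset of the overlap inside this rank's local shard
        "shard_start": overlap_start - shard_world_start,
        # number of elements of this parameter owned by this rank
        "size": overlap_end - overlap_start,
    }

# Example: a 10-element parameter starting at world index 95, on a rank
# that owns world indexes [100, 150).
print(slice_index_entry(95, 105, 100, 150))
# -> {'orig_start': 5, 'shard_start': 0, 'size': 5}
```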
@@ -872,9 +874,15 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            # Allocate shards.
            # (Non-fp32 shards are for convenience; e.g., intermediaries
            # between model params and main fp32 shard. Necessary???)
            main_param_shards = {
                ty : allocate_shard(local_shard_size, ty)
                for ty in model_main_dtypes}
            # main_param_shards = {
            #     ty : allocate_shard(local_shard_size, ty)
            #     for ty in model_main_dtypes}
            main_param_shards = {}
            for dtype in model_main_dtypes:
                main_param = allocate_shard(local_shard_size, dtype)
                main_param.grad = allocate_shard(local_shard_size, dtype)
                # pax(0, {"main_param": main_param})
                main_param_shards[dtype] = main_param
            # self.main_param_shard_groups.append(main_param_shards)
            local_shard_info["data"] = main_param_shards
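The dict comprehension is replaced by a loop so that each main shard also gets a .grad buffer of the same size and dtype. allocate_shard is not defined in this diff; the following is a hedged guess at its behavior (a flat, zero-initialized tensor), plus the loop it feeds.

```python
import torch

# Hedged guess at allocate_shard's behavior; the real helper is elsewhere
# in the repo and may allocate differently (e.g. uninitialized, pinned, etc.).
def allocate_shard(shard_size, dtype,
                   device="cuda" if torch.cuda.is_available() else "cpu"):
    return torch.zeros(shard_size, dtype=dtype, device=device)

# Mirrors the loop in the hunk: one main shard per dtype, each with its own
# same-shaped .grad buffer attached.
local_shard_size = 1024
model_main_dtypes = {torch.float}
main_param_shards = {}
for dtype in model_main_dtypes:
    main_param = allocate_shard(local_shard_size, dtype)
    main_param.grad = allocate_shard(local_shard_size, dtype)
    main_param_shards[dtype] = main_param
```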
@@ -891,6 +899,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())
        # >>>
        # pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
        # <<<

    # def get_loss_scale(self):
    #     if self.grad_scaler is None:
    #         return self._scale_one
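The load_state_dict(state_dict()) round trip works because, in current PyTorch releases, Optimizer.load_state_dict casts floating-point per-param state tensors (e.g. Adam's exp_avg / exp_avg_sq) to the dtype and device of the current params. A small standalone demonstration, independent of Megatron:

```python
import torch

p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adam([p], lr=1e-3)
p.grad = torch.randn(4)
opt.step()                              # creates fp32 exp_avg / exp_avg_sq state

p.data = p.data.double()                # the parameter switches to fp64
opt.load_state_dict(opt.state_dict())   # per-param state is re-cast to fp64
print(opt.state[p]["exp_avg"].dtype)    # torch.float64
```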
@@ -911,7 +923,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        for main_group in self.optimizer.param_groups:
            params.extend(main_group["params"])
        _zero_grad_group_helper(params, set_to_none)
        # _zero_grad_group_helper(params, set_to_none)
        _zero_grad_group_helper(params, set_to_none = False)
        # pax(0, {
        #     "model_param_groups" : self.model_param_groups,
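This is the fix named in the commit message: the caller-supplied set_to_none is ignored and gradients are zeroed in place, so any tensor that aliases the .grad buffers stays valid. A sketch of a zero-grad helper with the same signature as the one called here; Megatron's actual _zero_grad_group_helper may differ in detail.

```python
# set_to_none=False keeps each .grad tensor allocated and zeroes it in place;
# set_to_none=True would drop the tensor entirely, invalidating any views.
def _zero_grad_group_helper(group, set_to_none):
    for param in group:
        if param.grad is None:
            continue
        if set_to_none:
            param.grad = None
        else:
            if param.grad.grad_fn is not None:
                param.grad.detach_()
            else:
                param.grad.requires_grad_(False)
            param.grad.zero_()
```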
@@ -920,6 +933,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
    def reduce_gradients(self, model):

        # >>>
        pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
        # <<<
        # >>>
        args = get_args()
        # timers = get_timers()
@@ -962,27 +979,32 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            main_slice_index_map = local_shard_info["param_slice_index_map"]
            for param, main_slice_indexes in main_slice_index_map.items():
                main_param_start_index = main_slice_indexes["param_start"]
                main_shard_start_index = main_slice_indexes["shard_start"]
                main_slice_size = ddd
                main_size = main_shard_indexesddd
                main_slice_orig_start_index = main_slice_indexes["orig_start"]
                main_slice_shard_start_index = main_slice_indexes["shard_start"]
                main_slice_size = main_slice_indexes["size"]
                dtype_model_dict = param_model_map[param]
                dtype = dtype_model_dict["dtype"]
                vmodel = dtype_model_dict["model"]
                model_grad_buffer = vmodel._grad_buffers[dtype]
                model_grad_buffer_start_index = \
                    vmodel._grad_buffer_param_index_map[dtype][param][0]
                # model_grad_buffer_indexes = [ model_grad_buffer_start_index + i
                #     for i in main_
                # model_grad_view = model_grad_buffer.data[
                vmodel._grad_buffer_param_index_map[dtype][param][0] + \
                    main_slice_orig_start_index
                pax(0, {"model_grad_buffer_indexes" : model_grad_buffer_indexes})
                # main_grad_view = self.main_param_shard_groups \
                #     [group_index][torch.float].grad \
                #     [shard_indexes["shard"][0]:shard_indexes["shard"][1]]
                main_grad_view = local_shard_info["data"][torch.float]
                main_grad_view = self.main_param_shard_groups \
                    [group_index][torch.float].grad \
                    [shard_indexes["shard"][0]:shard_indexes["shard"][1]]
                pax(0, {
                    "local_shard_info" : local_shard_info,
                    "main_slice_orig_start_index" : main_slice_orig_start_index,
                    "main_slice_shard_start_index" : main_slice_shard_start_index,
                    "main_slice_size" : main_slice_size,
                    "model_grad_buffer_start_index" : model_grad_buffer_start_index,
                    "main_grad_view" : main_grad_view,
                })
                pax(0, {
                    # "dtype" : dtype,
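The pax(...) calls in this hunk are debug dumps while the gradient copy is still being worked out. Below is a self-contained toy of the copy this loop appears to be building toward, reusing the hunk's index names ("orig_start", "shard_start", "size"); it is illustrative only and not the finished implementation.

```python
import torch

model_grad_buffer = torch.randn(32, dtype=torch.float16)  # contiguous model grads
main_grad_shard = torch.zeros(16, dtype=torch.float32)    # this rank's fp32 grad shard

model_grad_buffer_start_index = 8    # where this param starts in the grad buffer
slice_indexes = {"orig_start": 2, "shard_start": 5, "size": 4}

# Slice of the model grad buffer owned by this rank for this parameter ...
src = model_grad_buffer[
    model_grad_buffer_start_index + slice_indexes["orig_start"]:
    model_grad_buffer_start_index + slice_indexes["orig_start"] + slice_indexes["size"]]

# ... and the matching slice of the local fp32 main-grad shard.
dst = main_grad_shard[
    slice_indexes["shard_start"]:
    slice_indexes["shard_start"] + slice_indexes["size"]]

dst.copy_(src)   # upcasts fp16 -> fp32 during the copy
```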