OpenDAS / Megatron-LM · Commits

Commit 4b843668
authored Feb 23, 2022 by Lawrence McAfee
parent c13c0a3e

fixed param_world_shard bug.
Showing 1 changed file with 25 additions and 17 deletions

megatron/optimizer/optimizer.py  +25 -17
...
@@ -756,7 +756,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # Add shard, if within range.
         if param_local_end > param_local_start:
             param_local_shard = Shard(param_local_start, param_local_end)
-            param_world_shard = param_local_shard.normalize(param_world_start)
+            # param_world_shard = param_local_shard.normalize(param_world_start)
+            param_world_shard = param_local_shard.normalize(param_local_start + gbuf_world_shard.start)
             sub_param_start = max(0, gbuf_world_shard.start - param_world_start)
             sub_param_shard = param_local_shard.normalize(sub_param_start)
             param_shard_map[param] = {
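The functional change in this hunk is the offset passed to param_local_shard.normalize(...): param_world_start is replaced by param_local_start + gbuf_world_shard.start. Neither Shard nor normalize() is defined in this diff, so the sketch below is only an assumed stand-in, chosen to match how the surrounding lines use them (a half-open [start, end) range, with normalize(start) rebasing the range to a new start while keeping its length); the way param_local_start/param_local_end are derived, and all the numbers, are hypothetical. It shows why the old offset gives a wrong world range exactly when a parameter begins before this rank's grad-buffer shard, which is the case the commented-out check in the next hunk (param_world_start < gbuf_world_shard.start) was probing.

# Assumed stand-in for the Shard helper (not part of this diff): a half-open
# [start, end) index range, where normalize(start) rebases the range to a new
# start and keeps its length.
from dataclasses import dataclass

@dataclass
class Shard:
    start: int
    end: int

    def size(self):
        return self.end - self.start

    def normalize(self, start=0):
        return Shard(start, start + self.size())

# Hypothetical numbers: this rank's grad-buffer shard covers world indices
# [1000, 2000) and one parameter occupies world indices [900, 1500), i.e. the
# parameter begins before the buffer shard.
gbuf_world_shard = Shard(1000, 2000)
param_world_start, param_world_end = 900, 1500

# Overlap of the parameter with the buffer shard, in buffer-local coordinates
# (how param_local_start/end are computed upstream is an assumption here).
param_local_start = max(0, param_world_start - gbuf_world_shard.start)                  # 0
param_local_end = min(param_world_end, gbuf_world_shard.end) - gbuf_world_shard.start   # 500
param_local_shard = Shard(param_local_start, param_local_end)                           # [0, 500)

# Old call: rebase to the parameter's world start.
old_world_shard = param_local_shard.normalize(param_world_start)                        # [900, 1400)

# New call from this commit: rebase to where the local range really sits in
# world coordinates.
new_world_shard = param_local_shard.normalize(param_local_start + gbuf_world_shard.start)  # [1000, 1500)

# The overlap's true world range is [1000, 1500), so only the new call is right.
print(old_world_shard)   # Shard(start=900, end=1400)
print(new_world_shard)   # Shard(start=1000, end=1500)

When the parameter starts at or after the buffer shard (param_world_start >= gbuf_world_shard.start), param_local_start + gbuf_world_shard.start equals param_world_start and the two calls coincide, which is presumably why the bug only showed up for parameters straddling a shard boundary.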
...
@@ -764,6 +766,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 "gbuf_local" : param_local_shard,
                 "param" : sub_param_shard,
             }
+            pax(1, {
+                "gbuf_world_shard" : gbuf_world_shard,
+                "param shards" : param_shard_map[param],
+            })
         # >>>
         # if param_world_start < gbuf_world_shard.start:
         # pax({"param shards": param_shard_map[param]})
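pax(...) and tp(...) are used throughout the debug code on this branch but are not defined in this file or in upstream Megatron-LM, so they are presumably local debugging utilities. Purely as a guess from the call sites (pax takes an optional flag and a dict of named values and halts; tp formats a tensor briefly), a minimal stand-in that makes these snippets runnable locally might look like this:

import sys

def tp(tensor):
    # Guessed "tensor print": a short summary instead of the full tensor.
    return "shape=%s dtype=%s" % (tuple(tensor.shape), tensor.dtype)

def pax(*args):
    # Guessed debug trap: the last argument is a dict of named values; any
    # leading int looks like a verbosity/rank flag. Dump everything and stop.
    values = args[-1] if args and isinstance(args[-1], dict) else {}
    for name, value in values.items():
        print("pax: %s = %r" % (name, value), file=sys.stderr)
    sys.exit(1)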
...
@@ -806,7 +812,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             "param_map" : param_shard_map,
         }
-        # pax(0, {"data": data})
+        # pax(1, {"data": data})
         return data
...
@@ -1155,9 +1161,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # ],
         # })
         pax(1, {
+            "data_parallel_rank" : data_parallel_rank,
             "main params" : self.get_main_params(),
-            "model params / world" : self.get_world_model_params(),
-            "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
+            # "model params / world" : self.get_world_model_params(),
+            **{"gbuf_view_items / %d" % i : v[2] for i, v in enumerate(gbuf_view_items)},
+            # "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
             # "model params / local" : self.get_local_model_param_views(),
         })
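The notable bit of the new pax(...) call above is the **{...} spread: instead of dumping only this rank's gbuf_view entry, it folds one labelled entry per element of gbuf_view_items into the debug dict. The layout of gbuf_view_items is not visible in this diff; the v[2] index only suggests that each element is a sequence whose third item is the tensor view, so the toy 3-tuples below are hypothetical.

# Hypothetical 3-tuples standing in for gbuf_view_items; only the v[2] slot
# (the "view") matters for the spread pattern shown in the diff.
gbuf_view_items = [
    ("float16", 0, "view-a"),
    ("float16", 1, "view-b"),
]

debug = {
    "data_parallel_rank": 1,
    **{"gbuf_view_items / %d" % i: v[2] for i, v in enumerate(gbuf_view_items)},
}
print(debug)
# {'data_parallel_rank': 1, 'gbuf_view_items / 0': 'view-a', 'gbuf_view_items / 1': 'view-b'}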
...
@@ -1311,19 +1319,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             model_view.detach().copy_(main_view)
             # Debug.
-            # pax(0, {
-            # "group_index" : group_index,
-            # "group_shard" : group_shard,
-            # "param" : tp(param),
-            # "model_index" : model_index,
-            # "gbuf_dtype" : str(gbuf_dtype),
-            # "model_grad_tensor" : tp(model_grad_tensor),
-            # "main_grad_tensor" : tp(main_grad_tensor),
-            # "model_grad_view" : tp(model_grad_view),
-            # "main_grad_view" : tp(main_grad_view),
-            # "model_shard" : str(model_shard),
-            # "main_shard" : str(main_shard),
-            # })
+            pax(1, {
+                "group_index" : group_index,
+                "group_shard" : group_shard,
+                "model_param" : tp(model_param),
+                "model_index" : model_index,
+                "dtype" : str(dtype),
+                "model_param" : tp(model_param),
+                "main_param" : tp(main_param),
+                "model_view" : tp(model_view),
+                "main_view" : tp(main_view),
+                "model_shard" : str(model_shard),
+                "main_shard" : str(main_shard),
+            })
             # pax(0, {
             # "model_gbuf_shards" : self.model_gbuf_shards,
...