OpenDAS / Megatron-LM · Commits

Commit 7ed649ed
authored Mar 29, 2022 by Lawrence McAfee

renamed: full_ -> model_.

parent 82e6730c
Showing 1 changed file with 44 additions and 38 deletions

megatron/optimizer/distrib_optimizer.py   +44  -38
@@ -200,12 +200,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                  param_gbuf_map,
                  opt_group_ranges):
 
-        # Three groups of parameters:
-        #   float16_groups: original float16 parameters
-        #   fp32_from_float16_groups: fp32 copy of float16 parameters
-        #   fp32_groups: original fp32 parameters
-        full_float16_groups = []
-        full_fp32_groups = []
+        # Parameter groups:
+        #   model_float16_groups: original float16 parameters
+        #   model_fp32_groups: original fp32 parameters
+        #   shard_float16_groups: shards of original float16 parameters
+        #   shard_fp32_groups: shards of original fp32 parameters
+        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
+        model_float16_groups = []
+        model_fp32_groups = []
         shard_float16_groups = []
         shard_fp32_groups = []
         shard_fp32_from_float16_groups = []
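For orientation, here is a minimal sketch (not Megatron-LM code; params, shard_start, and shard_end are illustrative stand-ins) of the kind of bookkeeping the renamed groups above capture: each model parameter keeps its full tensor in a model_* group, a flat slice of it in a shard_* group, and, for float16 parameters, an fp32 main copy of that slice.

import torch

def build_groups(params, shard_start, shard_end):
    # Hypothetical illustration: pair every full model param with the shard
    # this rank owns; float16 shards additionally get an fp32 main copy.
    model_float16_group, model_fp32_group = [], []
    shard_float16_group, shard_fp32_group = [], []
    shard_fp32_from_float16_group = []
    for p in params:
        shard = p.detach().view(-1)[shard_start:shard_end]
        if p.dtype in (torch.float16, torch.bfloat16):
            model_float16_group.append(p)
            shard_float16_group.append(shard)
            shard_fp32_from_float16_group.append(shard.clone().float())
        else:
            model_fp32_group.append(p)
            shard_fp32_group.append(shard)
    return (model_float16_group, model_fp32_group,
            shard_float16_group, shard_fp32_group,
            shard_fp32_from_float16_group)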
@@ -214,13 +216,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         for group_index, group_range in enumerate(opt_group_ranges):
 
             # Params of this group.
-            full_float16_params_this_group = []
-            full_fp32_params_this_group = []
+            model_float16_params_this_group = []
+            model_fp32_params_this_group = []
             shard_float16_params_this_group = []
             shard_fp32_params_this_group = []
             shard_fp32_from_float16_params_this_group = []
-            full_float16_groups.append(full_float16_params_this_group)
-            full_fp32_groups.append(full_fp32_params_this_group)
+            model_float16_groups.append(model_float16_params_this_group)
+            model_fp32_groups.append(model_fp32_params_this_group)
             shard_float16_groups.append(shard_float16_params_this_group)
             shard_fp32_groups.append(shard_fp32_params_this_group)
             shard_fp32_from_float16_groups.append(
@@ -251,7 +253,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     shard_main_param.shared = model_param.shared
 
                     # Add to group.
-                    full_float16_params_this_group.append(model_param)
+                    model_float16_params_this_group.append(model_param)
                     shard_float16_params_this_group.append(shard_model_param)
                     shard_fp32_from_float16_params_this_group.append(shard_main_param)
@@ -259,7 +261,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 elif model_param.type() == 'torch.cuda.FloatTensor':
                     shard_model_param = model_param.view(-1) \
                         [param_range.start:param_range.end]
-                    full_fp32_params_this_group.append(model_param)
+                    model_fp32_params_this_group.append(model_param)
                     shard_fp32_params_this_group.append(shard_model_param)
                     mpu.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
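As an aside, the model_param.view(-1)[param_range.start:param_range.end] pattern in this hunk produces a shard that aliases the parameter's storage rather than copying it. A standalone illustration in plain PyTorch (start and end are arbitrary values, not the optimizer's actual ranges):

import torch

param = torch.zeros(4, 3)
start, end = 2, 7                      # hypothetical range owned by this rank
shard = param.view(-1)[start:end]      # a view, not a copy

shard.fill_(1.0)                       # writes are visible through the full param
assert param.view(-1)[start:end].eq(1.0).all()
assert shard.data_ptr() == param.data_ptr() + start * param.element_size()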
@@ -280,8 +282,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         ]
 
         return (
-            full_float16_groups,
-            full_fp32_groups,
+            model_float16_groups,
+            model_fp32_groups,
             shard_float16_groups,
             shard_fp32_groups,
             shard_fp32_from_float16_groups,
@@ -315,8 +317,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
 
         # Allocate main param shards.
         (
-            self.full_float16_groups,
-            self.full_fp32_groups,
+            self.model_float16_groups,
+            self.model_fp32_groups,
             self.shard_float16_groups,
             self.shard_fp32_groups,
             self.shard_fp32_from_float16_groups,
@@ -333,6 +335,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_model_param_range_map(self, param):
+        '''
+        Given a model param, get the index sub-range of the param that this
+        data-parallel rank owns.
+        '''
         model_index, dtype = self.model_param_gbuf_map[param]
         gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
         param_range_map = gbuf_range_map["param_map"][param]
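For context, a toy sketch of the lookup this method performs; the nested layout and the Range type below are illustrative stand-ins, not the actual Megatron-LM data structures:

from dataclasses import dataclass

@dataclass
class Range:
    # Stand-in for a contiguous index sub-range owned by this data-parallel rank.
    start: int
    end: int

    @property
    def size(self):
        return self.end - self.start

# Hypothetical shape of the bookkeeping: model index -> dtype -> grad-buffer
# info, with a per-parameter "param_map" of range maps.
model_param_gbuf_map = {"layer0.weight": (0, "torch.float16")}
model_gbuf_ranges = [
    {"torch.float16": {"param_map": {"layer0.weight": {"param": Range(0, 1024)}}}},
]

def get_model_param_range_map(param):
    model_index, dtype = model_param_gbuf_map[param]
    gbuf_range_map = model_gbuf_ranges[model_index][dtype]
    return gbuf_range_map["param_map"][param]

print(get_model_param_range_map("layer0.weight")["param"].size)  # 1024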
@@ -390,8 +396,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         fragmentation; in the case of set_to_none==True, the space
         used by this field can be safely deallocated at this point."""
         for groups in (
-                self.full_float16_groups,
-                self.full_fp32_groups,
+                self.model_float16_groups,
+                self.model_fp32_groups,
                 self.shard_float16_groups,            # grad empty/unused here?
                 self.shard_fp32_groups,               # throws grad-access warning
                 self.shard_fp32_from_float16_groups):
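The docstring fragment above concerns grad-memory reuse; a minimal, generic illustration of the set_to_none trade-off in plain PyTorch (not the Megatron-LM helper this loop actually calls):

import torch

def zero_grad_group(group, set_to_none=True):
    # With set_to_none=True the .grad tensors are dropped, so their memory can
    # be freed or reused; with False they are kept and zeroed in place.
    for param in group:
        if param.grad is None:
            continue
        if set_to_none:
            param.grad = None
        else:
            param.grad.detach_()
            param.grad.zero_()

p = torch.nn.Parameter(torch.randn(4))
p.sum().backward()
zero_grad_group([p], set_to_none=True)
assert p.grad is None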
@@ -502,46 +508,46 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def _copy_model_grads_to_main_grads(self):
 
-        def copy_group_grads(full_model_groups, shard_main_groups):
-            for full_model_group, shard_main_group in zip(full_model_groups,
+        def copy_group_grads(model_groups, shard_main_groups):
+            for model_group, shard_main_group in zip(model_groups,
                                                      shard_main_groups):
-                for full_model_param, shard_main_param in zip(full_model_group,
+                for model_param, shard_main_param in zip(model_group,
                                                          shard_main_group):
 
-                    param_range_map = self.get_model_param_range_map(full_model_param)
+                    param_range_map = self.get_model_param_range_map(model_param)
                     param_range = param_range_map["param"]
                     assert param_range.size == shard_main_param.nelement()
 
-                    full_model_grad = full_model_param.main_grad
-                    shard_model_grad = full_model_grad.view(-1) \
+                    model_grad = model_param.main_grad
+                    shard_model_grad = model_grad.view(-1) \
                         [param_range.start:param_range.end]
                     shard_main_param.grad = shard_model_grad.float()
 
-        copy_group_grads(self.full_float16_groups,
+        copy_group_grads(self.model_float16_groups,
                          self.shard_fp32_from_float16_groups)
-        copy_group_grads(self.full_fp32_groups,
+        copy_group_grads(self.model_fp32_groups,
                          self.shard_fp32_groups)
 
     def _copy_main_params_to_model_params(self):
 
-        def copy_group_params(shard_main_groups, full_model_groups):
-            for shard_main_group, full_model_group in zip(shard_main_groups,
-                                                          full_model_groups):
-                for shard_main_param, full_model_param in zip(shard_main_group,
-                                                              full_model_group):
+        def copy_group_params(shard_main_groups, model_groups):
+            for shard_main_group, model_group in zip(shard_main_groups,
+                                                     model_groups):
+                for shard_main_param, model_param in zip(shard_main_group,
+                                                         model_group):
 
-                    param_range_map = self.get_model_param_range_map(full_model_param)
+                    param_range_map = self.get_model_param_range_map(model_param)
                     param_range = param_range_map["param"]
                     assert param_range.size == shard_main_param.nelement()
 
-                    full_model_grad = full_model_param.main_grad
-                    shard_model_grad = full_model_grad.view(-1) \
+                    model_grad = model_param.main_grad
+                    shard_model_grad = model_grad.view(-1) \
                         [param_range.start:param_range.end]
                     shard_model_grad.data.copy_(shard_main_param)
 
         copy_group_params(self.shard_fp32_from_float16_groups,
-                          self.full_float16_groups)
+                          self.model_float16_groups)
         copy_group_params(self.shard_fp32_groups,
-                          self.full_fp32_groups)
+                          self.model_fp32_groups)
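For intuition, a condensed sketch of the slice-and-copy pattern these two helpers apply per parameter; the tensors, start, and end below are illustrative stand-ins rather than the optimizer's actual buffers:

import torch

model_param = torch.randn(8).half()              # full float16 model parameter
main_grad = torch.randn(8)                       # stand-in for model_param.main_grad
start, end = 2, 6                                # hypothetical owned sub-range
shard_main_param = model_param.view(-1)[start:end].clone().float()

# Grads -> main: the fp32 main shard's grad is the owned slice of the grad buffer,
# mirroring shard_main_param.grad = shard_model_grad.float() above.
shard_main_param.grad = main_grad.view(-1)[start:end].float()

# (an optimizer.step() over shard_main_param would go here)

# Main -> model: the updated fp32 shard is copied back into the owned slice of
# the buffer, mirroring shard_model_grad.data.copy_(shard_main_param) above.
main_grad.view(-1)[start:end].copy_(shard_main_param)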