OpenDAS / Megatron-LM · Commits

Commit 57bfa7ca
Perform distributed optimizer's all-gather in param dtype (instead of grad dtype)

Authored Sep 30, 2022 by Lawrence McAfee
Committed by Jared Casper, Sep 30, 2022
Parent: 1afe3541

Showing 3 changed files with 78 additions and 43 deletions (+78 -43)
megatron/optimizer/__init__.py            +1  -0
megatron/optimizer/distrib_optimizer.py   +72 -40
megatron/optimizer/optimizer.py           +5  -3
megatron/optimizer/__init__.py (view file @ 57bfa7ca)

```diff
@@ -145,6 +145,7 @@ def get_megatron_optimizer(model,
                   args.use_contiguous_buffers_in_local_ddp,
                   args.fp16,
                   args.bf16,
+                  args.params_dtype,
                   grad_scaler,
                   model)
```
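For context, `args.params_dtype` is not defined in this commit; it comes from Megatron-LM's argument parsing, where (to the best of my recollection of arguments.py, shown here only as a hedged sketch) it is derived from the `--fp16`/`--bf16` flags:

```python
import torch
from types import SimpleNamespace

# Hypothetical stand-in for Megatron's parsed arguments.
args = SimpleNamespace(fp16=True, bf16=False)

# Sketch of how params_dtype is typically derived from the precision flags
# (see megatron/arguments.py; the exact code there may differ).
params_dtype = torch.float32       # default: full-precision params
if args.fp16:
    params_dtype = torch.half      # 2-byte params
elif args.bf16:
    params_dtype = torch.bfloat16  # 2-byte params

print(params_dtype)  # torch.float16 for this example
```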
megatron/optimizer/distrib_optimizer.py (view file @ 57bfa7ca)
```diff
@@ -351,7 +351,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 fp16, bf16, grad_scaler, models):
+                 fp16, bf16, params_dtype, grad_scaler, models):
         """
         See top of class definition for argument descriptions.
@@ -365,7 +365,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
             params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-            fp16, bf16, grad_scaler, models)
+            fp16, bf16, params_dtype, grad_scaler, models)

         # Verify that contiguous buffers are being used.
         # - Note: this should already be checked in arguments.py.
```
```diff
@@ -394,6 +394,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             self.model_param_gbuf_map,
             self.opt_group_ranges)

+        # Initialize param buffers.
+        # - These are views on the DDP model's grad buffers, that share
+        #   storage & have their own dtype. This is safe because the param
+        #   dtype size is always <= grad dtype size.
+        self.param_buffers = []
+        for model_index, model in enumerate(self.models):
+            current_param_buffers = {}
+            for dtype, grad_buffer in model._grad_buffers.items():
+                param_buffer = torch.tensor(grad_buffer.data.storage()._untyped(),
+                                            dtype = params_dtype,
+                                            device = grad_buffer.data.device)
+                param_buffer = param_buffer[:grad_buffer.numel_padded]
+                current_param_buffers[dtype] = param_buffer
+            self.param_buffers.append(current_param_buffers)
+
         # Update optimizer groups.
         # - Also, leverage state_dict() and load_state_dict() to
         #   recast preexisting per-param state tensors.
```
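The key idea above is reinterpreting the grad buffer's raw storage under the (smaller or equal) param dtype, so the all-gather can reuse existing memory. The commit does this through the private untyped-storage API; the minimal sketch below illustrates the same aliasing idea with `Tensor.view(dtype)` (a different mechanism, used here only because it is stable across PyTorch versions) and shows why slicing to `numel_padded` keeps the fp16 view inside the first half of an fp32 buffer:

```python
import torch

# A stand-in for one DDP grad buffer: fp32, already padded to a multiple of
# the data-parallel world size (numel_padded == numel here for simplicity).
grad_buffer = torch.zeros(8, dtype=torch.float32)
numel_padded = grad_buffer.numel()

# Reinterpret the same bytes as fp16: 8 fp32 elements -> 16 fp16 elements,
# then keep only the first numel_padded of them. No new memory is allocated.
param_buffer = grad_buffer.view(torch.float16)[:numel_padded]

assert param_buffer.data_ptr() == grad_buffer.data_ptr()  # shared storage
assert param_buffer.numel() == numel_padded               # same logical size

# Writing through the param view mutates the grad buffer's bytes, which is
# why the all-gather can land directly in this memory instead of allocating
# fresh tensors.
param_buffer.fill_(1.0)
print(grad_buffer[:2])  # nonsense when read back as fp32 -- the bytes are shared
```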
```diff
@@ -488,36 +503,48 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         _zero_grad_group_helper(group, set_to_none)


-    def get_model_grad_buffer_dp_views(self):
+    @staticmethod
+    def get_model_buffer_dp_views(model_buffers):
         """
-        Get shard views of each of the DDP's grad buffers.
+        Get shard views of each of the DDP's param/grad buffers.

         In this nested list, the top level is grouped by the virtual model
-        index and the grad buffer's data type. The sub-level is a list of
-        shards of that grad buffer, where each shard in the list represents
-        a contiguous view of the grad buffer, that is owned by a data-parallel
+        index and the buffer's data type. The sub-level is a list of
+        shards of that buffer, where each shard in the list represents
+        a contiguous view of the buffer, that is owned by a data-parallel
         rank. The shard boundary does not respect parameter boundaries, and
         so the elements of some parameters are split across data parallel
         ranks.

-        Additionally, return references to the entire grad buffers, for use
+        Additionally, return references to the entire buffers, for use
         in _reduce_scatter_base and _all_gather_base.
         """

         data_parallel_world_size = mpu.get_data_parallel_world_size()

-        # Grad buffer views.
-        gbuf_view_items = []
-        for model_index, model in enumerate(self.models):
-            for dtype, gbuf in model._grad_buffers.items():
+        # Buffer views.
+        view_items = []
+        for model_index, buffers in enumerate(model_buffers):
+            for dtype, buf in buffers.items():

-                assert gbuf.numel_padded % data_parallel_world_size == 0
-                shard_size = int(gbuf.numel_padded / data_parallel_world_size)
-                gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
-                              for r in range(data_parallel_world_size)]
-                gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))
+                assert buf.numel() % data_parallel_world_size == 0
+                shard_size = int(buf.numel() / data_parallel_world_size)
+                buf_views = [buf[(r*shard_size):((r+1)*shard_size)]
+                             for r in range(data_parallel_world_size)]
+                view_items.append((model_index, dtype, buf, buf_views))

-        return gbuf_view_items
+        return view_items
+
+
+    def get_model_grad_buffer_dp_views(self):
+        return self.get_model_buffer_dp_views([
+            {dtype : mem_buffer.data}
+            for model in self.models
+            for dtype, mem_buffer in model._grad_buffers.items()])
+
+
+    def get_model_param_buffer_dp_views(self):
+        return self.get_model_buffer_dp_views(self.param_buffers)


     def reduce_model_grads(self, args, timers):
```
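The docstring's point that shard boundaries ignore parameter boundaries is easy to see with a toy buffer. Below is a self-contained sketch (hypothetical sizes and parameter layout, no real distributed setup) of the same slicing arithmetic used in `get_model_buffer_dp_views`:

```python
import torch

data_parallel_world_size = 4

# Toy contiguous buffer holding two flattened "parameters": 10 + 6 elements,
# already padded so the total is divisible by the data-parallel world size.
buf = torch.arange(16, dtype=torch.float16)
param_index_map = {"weight": (0, 10), "bias": (10, 16)}  # hypothetical layout

assert buf.numel() % data_parallel_world_size == 0
shard_size = buf.numel() // data_parallel_world_size     # 4 elements per rank

# Same slicing as the commit: rank r owns elements [r*shard_size, (r+1)*shard_size).
buf_views = [buf[(r * shard_size):((r + 1) * shard_size)]
             for r in range(data_parallel_world_size)]

# Rank 2 owns elements 8..11, i.e. the tail of "weight" and the head of "bias":
# the shard boundary falls inside a parameter, exactly as the docstring notes.
print(buf_views[2])  # tensor([ 8.,  9., 10., 11.], dtype=torch.float16)
```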
```diff
@@ -574,9 +601,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         """
         All-gather updated model params.

-        The DDP's grad buffer is used for the all-gather, and thus no
+        The DDP's param buffer is used for the all-gather, and thus no
         tensors are dynamically allocated. After the all-gather, the params
-        can be copied from param.main_grad to param.
+        can be copied from the param buffer to the param.
         """

         timers('params-all-gather', log_level=1).start(
```
```diff
@@ -586,26 +613,28 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         data_parallel_group = mpu.get_data_parallel_group()

         # All-gather updated main params.
-        # - All grad buffer views are guaranteed to have the same num elements
-        #   across all data parallel ranks, with grad buffer padding that is done
-        #   in distributed.py. Thus, all sub-views will have consistent start/end
-        #   indexes across data parallel ranks.
-        gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        for index, (model_index, dtype, gbuf, gbuf_views) \
-            in enumerate(gbuf_view_items):
+        # - All param buffer views are guaranteed to have the same num elements
+        #   across all data parallel ranks, due to grad buffer padding that is
+        #   done in distributed.py, and extended to the param buffers. Thus,
+        #   all sub-views will have consistent start/end indexes across data
+        #   parallel ranks.
+        pbuf_view_items = self.get_model_param_buffer_dp_views()
+        for index, (model_index, dtype, pbuf, pbuf_views) \
+            in enumerate(pbuf_view_items):

             torch.distributed._all_gather_base(
-                gbuf,
-                gbuf_views[data_parallel_rank],
+                pbuf,
+                pbuf_views[data_parallel_rank],
                 group = data_parallel_group,
             )

-        # Each model param now contains its updated values in its
-        # '.main_grad' field.
-        for model in self.models:
+        # Copy from param buffer to each param.
+        for model_id, model in enumerate(self.models):
             for dtype, param_map in model._grad_buffer_param_index_map.items():
-                for param in param_map:
-                    param.detach().copy_(param.main_grad)
+                for param, buf_range in param_map.items():
+                    param_buf = self.param_buffers[model_id][dtype]
+                    param_buf_shard = param_buf[buf_range[0]:buf_range[1]]
+                    param.view(-1).detach().copy_(param_buf_shard)

         timers('params-all-gather').stop()
```
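The benefit of gathering in the param dtype is message size: when grads are kept in fp32 for accumulation/all-reduce (as with Megatron's fp32 grad-accumulation option) while params are fp16 or bf16, the gathered buffer is half as many bytes. A back-of-the-envelope sketch with a hypothetical buffer size:

```python
import torch

# Hypothetical padded size of one DDP buffer (number of elements).
numel_padded = 1_073_741_824  # ~1B elements

# Bytes moved by the all-gather when it runs over the grad buffer (fp32)
# versus the aliased param buffer (fp16/bf16).
bytes_in_grad_dtype  = numel_padded * torch.finfo(torch.float32).bits // 8
bytes_in_param_dtype = numel_padded * torch.finfo(torch.float16).bits // 8

print(f"all-gather in grad dtype:  {bytes_in_grad_dtype  / 2**30:.1f} GiB")
print(f"all-gather in param dtype: {bytes_in_param_dtype / 2**30:.1f} GiB")
# -> 4.0 GiB vs 2.0 GiB per gathered buffer, i.e. a 2x reduction in traffic.
```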
```diff
@@ -685,14 +714,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                                            model_group):

                 param_range_map = self.get_model_param_range_map(model_param)
-                param_range = param_range_map["param"]
-                assert param_range.size == shard_main_param.nelement()
+                world_range = param_range_map["gbuf_world"]

-                model_grad = model_param.main_grad
-                shard_model_grad = model_grad.view(-1) \
-                    [param_range.start:param_range.end]
-                shard_model_grad.data.copy_(shard_main_param)
+                assert world_range.size == shard_main_param.nelement()
+
+                model_id, dtype = self.model_param_gbuf_map[model_param]
+                model_param_buffer = self.param_buffers[model_id][dtype]
+
+                shard_model_param = model_param_buffer.view(-1) \
+                    [world_range.start:world_range.end]
+
+                shard_model_param.data.copy_(shard_main_param)

         # Copy shard groups to model groups.
         copy_group_params(self.shard_fp32_from_float16_groups,
```
megatron/optimizer/optimizer.py (view file @ 57bfa7ca)

```diff
@@ -335,6 +335,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
             is using a contiguous buffer to hold the model grads.
         fp16: if true, the model is running in fp16.
         bf16: if true, the model is running in bfloat16.
+        params_dtype: used by distributed optimizer.
         grad_scaler: used for scaling gradients. Note that this can be
             None. This case happens when `bf16 = True` and we don't
             use any loss scale. Note that for `bf16 = True`, we can have
@@ -346,7 +347,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):

     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 fp16, bf16, grad_scaler,
+                 fp16, bf16, params_dtype, grad_scaler,
                  models):

         super().__init__(
@@ -356,6 +357,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):

         self.fp16 = fp16
         self.bf16 = bf16
+        self.params_dtype = params_dtype
         self.grad_scaler = grad_scaler

         # None grad scaler is only supported for bf16.
@@ -505,12 +507,12 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):

     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 fp16, bf16, grad_scaler, models):
+                 fp16, bf16, params_dtype, grad_scaler, models):

         super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
             params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-            fp16, bf16, grad_scaler, models)
+            fp16, bf16, params_dtype, grad_scaler, models)

         # ======================
         # main parameter stuff
```