OpenDAS / apex · Commits

Commit 5ffb22d0 (unverified)
Authored Jun 14, 2022 by Thor Johnsen; committed by GitHub on Jun 14, 2022.

Merge pull request #1401 from timmoon10/dist-adam-zero

ZeRO-2 support in DistributedFusedAdam
Parents: 265b451d, 846f7f8a

Showing 4 changed files with 925 additions and 1667 deletions:

  apex/contrib/optimizers/distributed_fused_adam.py     +824  -592
  apex/contrib/optimizers/distributed_fused_adam_v2.py    +0  -615
  apex/contrib/optimizers/distributed_fused_adam_v3.py    +0  -325
  tests/L0/run_optimizers/test_dist_adam.py              +101  -135
apex/contrib/optimizers/distributed_fused_adam.py  (view file @ 5ffb22d0)
import collections
import contextlib
import enum
import importlib
import inspect
import math
import threading

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
from torch.distributed.distributed_c10d import _get_default_group
import torch.distributed.distributed_c10d as c10d
class DistributedFusedAdam(torch.optim.Optimizer):

    """AdamW optimizer with ZeRO algorithm.

    Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.

    This implements the ZeRO-2 algorithm, which distributes the
    optimizer state and gradients between parallel processes. In
    particular, the parameters are flattened, grouped into fixed-size
    buckets, and the optimizer state for each bucket is sharded over
    the parallel processes. Options are provided to overlap the
    gradient synchronization with the backward pass compute.

    Adam was proposed in `Adam: A Method for Stochastic
    Optimization`_, AdamW in `Decoupled Weight Decay Regularization`_,
    and ZeRO in `ZeRO: Memory Optimizations Toward Training Trillion
    Parameter Models`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for
            computing running averages of gradient and its square.
            (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty)
            (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad
            variant of this algorithm from the paper
            `On the Convergence of Adam and Beyond`_ (default: False).
            This is not yet supported.
        dtype (torch.dtype, optional): datatype for optimizer state
            (default: torch.float32)
        grad_sync_dtype (torch.dtype, optional): datatype for gradient
            synchronization (default: same as dtype)
        param_sync_dtype (torch.dtype, optional): datatype for
            parameter synchronization (default: same as
            grad_sync_dtype)
        device (torch.device, optional): device for optimizer state
            (default: cuda). Currently only supports GPU.
        process_group (torch.distributed.ProcessGroup, optional):
            parallel processes participating in optimizer (default:
            default group in torch.distributed). This group is
            interpreted as a 2D grid with dimensions
            distributed_size x redundant_size.
        distributed_process_group (torch.distributed.ProcessGroup,
            optional): parallel processes to distribute optimizer
            state over (default: same as process_group)
        redundant_process_group (torch.distributed.ProcessGroup,
            optional): parallel processes to replicate optimizer state
            over (default: group only containing calling process)
        model_parallel (bool, optional): whether model parallelism is
            used (default: False)
        model_parallel_rank (int, optional): rank in model-parallel
            process group (default: 0)
        average_grad_sync (bool, optional): whether to use average
            reduction for gradient synchronization rather than sum
            (default: True)
        overlap_grad_sync (boolean, optional): whether to overlap
            gradient synchronization with backward pass compute
            (default: True)
        bucket_cap_mb (float, optional): bucket size in megabytes
            (default: 15)
        pipeline_size (int, optional): number of buckets to
            synchronize simultaneously (default: 2)
        fused_grad_copy (bool, optional): whether to use a fused kernel
            to fill bucket with gradients (default: False). Requires
            all parameters to have the same data type.
        max_grad_norm (float, optional): maximum L2 norm for gradient
            clipping (default: disabled)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
    .. _ZeRO\: Memory Optimizations Toward Training Trillion Parameter Models:
        https://arxiv.org/abs/1910.02054

    """
    class GradientStatus(enum.Enum):
        """Status of gradients within a bucket"""
        # Gradients are ready to use
        READY = enum.auto()
        # Bucket is partially filled with unreduced gradients
        PARTIALLY_FILLED = enum.auto()
        # Bucket is fully filled with unreduced gradients
        FULLY_FILLED = enum.auto()
        # Asynchronous reduction is in progress
        SYNCING = enum.auto()

    def __init__(self,
                 params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0.,
                 amsgrad=False,
                 dtype=torch.float32,
                 grad_sync_dtype=None,
                 param_sync_dtype=None,
                 device='cuda',
                 process_group=None,
                 distributed_process_group=None,
                 redundant_process_group=None,
                 model_parallel=False,
                 model_parallel_rank=0,
                 average_grad_sync=True,
                 overlap_grad_sync=True,
                 bucket_cap_mb=15,
                 pipeline_size=2,
                 fused_grad_copy=False,
                 max_grad_norm=0.,
                 ):
        defaults = dict(
            lr=lr,
            bias_correction=bias_correction,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
        )
        super(DistributedFusedAdam, self).__init__(params, defaults)

        # Adam options
        if amsgrad:
            raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')

        # Datatype options
        if grad_sync_dtype is None:
            grad_sync_dtype = dtype
        if param_sync_dtype is None:
            param_sync_dtype = grad_sync_dtype
        valid_dtypes = [
            (torch.float32, torch.float16, torch.float16),
            (torch.float32, torch.float32, torch.float32),
        ]
        if (dtype, grad_sync_dtype, param_sync_dtype) not in valid_dtypes:
            raise RuntimeError(
                'Invalid dtypes for DistributedFusedAdam '
                f'(dtype={dtype}, '
                f'grad_sync_dtype={grad_sync_dtype}, '
                f'param_sync_dtype={param_sync_dtype}))'
            )
        if device != 'cuda':
            raise RuntimeError('DistributedFusedAdam only supports GPU')
        self.dtype = dtype
        self.grad_sync_dtype = grad_sync_dtype
        self.param_sync_dtype = param_sync_dtype
        self.device = device

        # Process groups
        self.world_process_group = (
            _get_default_group()
            if process_group is None
            else process_group
        )
        self.distributed_process_group = (
            self.world_process_group
            if distributed_process_group is None
            else distributed_process_group
        )
        self.redundant_process_group = redundant_process_group
        self.world_size = torch.distributed.get_world_size(self.world_process_group)
        self.distributed_rank = torch.distributed.get_rank(self.distributed_process_group)
        self.distributed_size = torch.distributed.get_world_size(self.distributed_process_group)
        self.redundant_size = (
            1
            if self.redundant_process_group is None
            else torch.distributed.get_world_size(self.redundant_process_group)
        )
        if (self.world_size != self.distributed_size * self.redundant_size):
            raise RuntimeError(
                'Invalid process group configuration '
                f'(world process group size = {self.world_size}, '
                f'distributed process group size = {self.distributed_size}, '
                f'redundant process group size = {self.redundant_size})'
            )
        self.model_parallel = model_parallel
        self.model_parallel_rank = model_parallel_rank

        # Grad sync options
        if fused_grad_copy:
            _params = list(self.parameters())
            if (_params
                and any(p.dtype != self.grad_sync_dtype for p in _params)
                and any(p.device != self.device for p in _params)):
                raise RuntimeError(
                    'Attempted to use fused gradient copy in DistributedFusedAdam, '
                    'but parameters do not all have expected '
                    f'dtype ({self.grad_sync_dtype}) and device ({self.device})'
                )
        self.average_grad_sync = average_grad_sync
        self.overlap_grad_sync = overlap_grad_sync
        self.pipeline_size = pipeline_size
        self.fused_grad_copy = fused_grad_copy

        # Grad clipping options
        self.max_grad_norm = max_grad_norm

        # Determine bucket sizes
        dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8
        self.alignment = 128 // dtype_size
        bucket_size = 1024 * 1024 * bucket_cap_mb / dtype_size
        shard_size = bucket_size / self.distributed_size
        shard_size = (int(shard_size) // self.alignment) * self.alignment
        shard_size = max(shard_size, self.alignment)
        bucket_size = shard_size * self.distributed_size
        self.bucket_size = bucket_size
        self.shard_size = shard_size

        # Load CUDA kernels
        global fused_adam_cuda, distributed_adam_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
        distributed_adam_cuda = importlib.import_module("distributed_adam_cuda")
        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
        # Optimizer state
        self.state['buckets'] = []
        self.state['step'] = 0

        # Objects for gradient synchronization
        self._grads_generated = set()
        self._grads_to_copy = []
        self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)]

        # Check if collectives have no_copy option
        self._reduce_scatter_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
        )
        self._all_gather_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
        )

        # Attach hooks for gradient synchronization
        self._register_post_backward_hooks()
    def _register_post_backward_hooks(self):
        """Attach hooks for gradient synchronization

        Optimizer state for parameters are initialized lazily as they
        are encountered in the backward pass.

        """
        self._num_grads = 0
        self._grad_accs = []
        self._lock = threading.Lock()
        for param_group_id, group in enumerate(self.param_groups):
            for param_id, param in enumerate(group['params']):
                torch.distributed.broadcast(
                    param,
                    src=0,
                    group=self.world_process_group,
                )
                if param.requires_grad:
                    def wrapper(p, p_group_id, p_id):
                        p_tmp = p.expand_as(p)
                        grad_acc = p_tmp.grad_fn.next_functions[0][0]
                        def reduction_hook(*unused):
                            with self._lock:
                                if 'fragments' not in self.state[p]:
                                    self._init_param_state(p, p_group_id, p_id)
                                if self.overlap_grad_sync:
                                    self._start_grad_copy(p)
                                    self._try_start_bucket_grad_sync()
                        grad_acc.register_hook(reduction_hook)
                        self._grad_accs.append(grad_acc)
                    wrapper(param, param_group_id, param_id)
                    self._num_grads += 1
    def _init_param_state(
        self,
        param,
        param_group_id,
        param_id,
    ):
        """Initialize optimizer state for a parameter"""

        # Make sure there is at least one bucket
        if not self.state['buckets']:
            self._add_bucket()

        # Split parameter values into fragments
        # Note: Each fragment resides within a bucket
        param_start = 0
        param_size = param.numel()
        self.state[param]['fragments'] = []
        while param_start < param_size:

            # Get current bucket
            if not self.state['buckets']:
                self._add_bucket()
            bucket_id = len(self.state['buckets']) - 1
            bucket = self.state['buckets'][bucket_id]
            fragment_id = len(bucket['fragments'])

            # Determine fragment position within bucket
            if fragment_id == 0:
                bucket_start = 0
            else:
                bucket_start = bucket['fragments'][-1]['bucket_range'][1]
                # Pad until fragment is aligned
                bucket_start = (
                    (bucket_start + self.alignment - 1)
                    // self.alignment
                    * self.alignment
                )
            fragment_size = min(
                param_size - param_start,
                self.bucket_size - bucket_start,
            )
            param_end = param_start + fragment_size
            bucket_end = bucket_start + fragment_size

            # Create new bucket if current one is full
            if fragment_size <= 0:
                self._add_bucket()
                continue

            # Fragment position within local shard
            shard_id = self.distributed_rank
            shard_start = bucket_start - self.shard_size * shard_id
            shard_end = bucket_end - self.shard_size * shard_id
            shard_start = min(max(shard_start, 0), self.shard_size)
            shard_end = min(max(shard_end, 0), self.shard_size)
            in_local_shard = shard_start < shard_end
            if in_local_shard:
                shard_bucket_start = shard_start + self.shard_size * shard_id
                shard_bucket_end = shard_bucket_start + shard_end - shard_start
                shard_param_start = shard_bucket_start - bucket_start + param_start
                shard_param_end = shard_param_start + shard_end - shard_start
            else:
                shard_bucket_start, shard_bucket_end = None, None
                shard_param_start, shard_param_end = None, None

            # Record fragment info
            fragment = {
                # Parameter group index
                'param_group_id': param_group_id,
                # Parameter index within parameter group
                'param_id': param_id,
                # Bucket index
                'bucket_id': bucket_id,
                # Range within flattened parameter buffer
                'param_range': (param_start, param_end),
                # Range within bucket
                'bucket_range': (bucket_start, bucket_end),
                # Whether fragment is in local shard of bucket
                'in_local_shard': in_local_shard,
                # Range within local shard
                'shard_range': (shard_start, shard_end),
                # Range of local fragment shard within bucket
                'shard_bucket_range': (shard_bucket_start, shard_bucket_end),
                # Range of local fragment shard within parameter
                'shard_param_range': (shard_param_start, shard_param_end),
            }
            self.state[param]['fragments'].append(fragment)
            bucket['fragments'].append(fragment)
            param_start = param_end

        # Initialize master param buffer
        for fragment in self.state[param]['fragments']:
            if fragment['in_local_shard']:
                bucket_id = fragment['bucket_id']
                bucket = self.state['buckets'][bucket_id]
                param_start, param_end = fragment['shard_param_range']
                shard_start, shard_end = fragment['shard_range']
                model_param_fragment = param.view(-1)[param_start:param_end]
                master_param_fragment = bucket['params_shard'][shard_start:shard_end]
                master_param_fragment.copy_(model_param_fragment)
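    # Illustrative example (not part of the original file): with bucket_size=8,
    # shard_size=4 and distributed_rank=1, a 6-element parameter placed at the
    # start of a bucket gets a single fragment with
    #   param_range  = (0, 6)   -> slice of the flattened parameter
    #   bucket_range = (0, 6)   -> slice of the bucket
    #   shard_range  = (0, 2)   -> rank 1 owns bucket positions 4..8, so only
    #                              parameter elements 4..6 land in its shard
    # i.e. shard_start = max(0 - 4*1, 0) = 0 and shard_end = min(6 - 4*1, 4) = 2,
    # matching the arithmetic in _init_param_state above.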
    def _add_bucket(self):
        """Construct a bucket for optimizer state"""
        self.state['buckets'].append({
            # Parameter fragments associated with bucket
            'fragments': [],
            # Gradient buffers
            'grads_shard': None,
            'grads_bucket': None,
            'curr_grads_shard': None,  # For current micro-batch
            # Optimizer state
            'params_shard': torch.zeros(
                [self.shard_size], dtype=self.dtype, device=self.device),
            'exp_avg_shard': torch.zeros(
                [self.shard_size], dtype=self.dtype, device=self.device),
            'exp_avg_sq_shard': torch.zeros(
                [self.shard_size], dtype=self.dtype, device=self.device),
            # Status of parameter gradients
            'gradient_status': self.GradientStatus.READY,
            # Distributed request object for gradient synchronization
            'grad_sync_request': None,
        })
def
zero_grad
(
self
,
set_to_none
=
True
):
"""Clear parameter gradients"""
for
group
in
self
.
param_groups
:
for
group
in
self
.
param_groups
:
self
.
_param_group
=
group
for
param
in
group
[
'params'
]:
prev
=
None
if
param
.
grad
is
None
or
set_to_none
:
beta1
,
beta2
=
group
[
'betas'
]
param
.
grad
=
None
bias_correction
=
1
if
group
[
'bias_correction'
]
else
0
else
:
eps
=
group
[
'eps'
]
param
.
grad
.
zero_
()
weight_decay
=
group
[
'weight_decay'
]
for
bucket
in
self
.
state
[
'buckets'
]:
for
p
in
group
[
'params'
]:
bucket
[
'grads_shard'
]
=
None
# broadcast from rank 0 of current process group
bucket
[
'grads_bucket'
]
=
None
torch
.
distributed
.
broadcast
(
p
,
src
=
self
.
_available_ranks
[
0
],
group
=
self
.
_current_process_group
)
bucket
[
'curr_grads_shard'
]
=
None
if
not
p
.
requires_grad
:
bucket
[
'gradient_status'
]
=
self
.
GradientStatus
.
READY
continue
self
.
_grads_generated
=
set
()
self
.
_model_params
.
append
(
p
)
# Multiple param groups support:
def
_start_grad_copy
(
self
,
param
):
# store one hyperparam item per parameter tensor
"""Copy parameter gradient to corresponding buckets
self
.
_group_properties
.
append
((
beta1
,
The copy is deferred if using a fused copy kernel.
beta2
,
bias_correction
,
eps
,
weight_decay
))
p_grads_size
=
p
.
numel
()
def
wrapper
(
param
,
param_i
,
param_grads_size
,
param_offset
):
param_tmp
=
param
.
expand_as
(
param
)
grad_acc
=
param_tmp
.
grad_fn
.
next_functions
[
0
][
0
]
def
allreduce_hook
(
*
unused
):
self
.
_do_overlapped_reduction
(
param_i
,
param_grads_size
,
param_offset
,
param
)
grad_acc
.
register_hook
(
allreduce_hook
)
self
.
_grad_accs
.
append
(
grad_acc
)
self
.
_grads_info
.
append
({
"param_grads_size"
:
p_grads_size
,
"param_offset"
:
p_offset
})
wrapper
(
p
,
p_i
,
p_grads_size
,
p_offset
)
p_offset
+=
p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if
prev
is
not
None
and
(
prev
.
data_ptr
()
+
prev
.
numel
()
*
prev
.
element_size
()
!=
p
.
data_ptr
()):
p_offset
=
((
p_offset
+
63
)
//
64
)
*
64
prev
=
p
p_i
+=
1
self
.
_grads_generated
=
[
False
]
*
len
(
self
.
_grads_info
)
self
.
_grads
=
[]
if
self
.
_overlap_reductions
:
self
.
_current_block
=
self
.
_num_blocks
self
.
_net_total_param_size
=
p_offset
self
.
_total_param_size
=
p_offset
dwu_min_page_size
=
256
*
self
.
_num_blocks
*
self
.
_num_chunks
*
self
.
_group_size
self
.
_total_param_size
=
((
self
.
_total_param_size
+
dwu_min_page_size
-
1
)
//
dwu_min_page_size
)
*
dwu_min_page_size
self
.
_block_size
=
self
.
_total_param_size
//
self
.
_num_blocks
self
.
_chunk_size
=
self
.
_block_size
//
self
.
_num_chunks
self
.
_shard_size
=
self
.
_chunk_size
//
self
.
_group_size
#print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
self
.
_low_param_i
=
[
0
]
*
self
.
_num_blocks
for
block_id
in
range
(
self
.
_num_blocks
-
1
,
-
1
,
-
1
):
p_i
=
len
(
self
.
_grads_info
)
-
1
while
p_i
>
0
and
self
.
_grads_info
[
p_i
][
"param_offset"
]
>
block_id
*
self
.
_block_size
:
p_i
-=
1
self
.
_low_param_i
[
block_id
]
=
p_i
#print(self._low_param_i)
self
.
_flat_grads
=
torch
.
zeros
([
self
.
_total_param_size
],
dtype
=
torch
.
float16
,
device
=
'cuda'
)
self
.
_new_params
=
torch
.
zeros
([
self
.
_total_param_size
],
dtype
=
torch
.
uint8
if
self
.
_e5m2_allgather
else
torch
.
float16
,
device
=
'cuda'
)
self
.
_mega_shard_size
=
self
.
_num_blocks
*
self
.
_num_chunks
*
self
.
_shard_size
# initialize master weights, moments buffers if not loaded from checkpoint
if
self
.
_fp32_p
is
None
:
self
.
_fp32_p
=
torch
.
zeros
([
self
.
_mega_shard_size
],
dtype
=
torch
.
float32
,
device
=
'cuda'
)
self
.
_fp32_m
=
torch
.
zeros
([
self
.
_mega_shard_size
],
dtype
=
torch
.
float32
,
device
=
'cuda'
)
self
.
_fp32_v
=
torch
.
zeros
([
self
.
_mega_shard_size
],
dtype
=
torch
.
float32
,
device
=
'cuda'
)
# FIXME: Rethink fp16 label since it's either uint8 or fp16
self
.
_fp16_p
=
torch
.
zeros
([
self
.
_mega_shard_size
],
dtype
=
torch
.
uint8
if
self
.
_e5m2_allgather
else
torch
.
float16
,
device
=
'cuda'
)
self
.
_fp16_g
=
torch
.
zeros
([
self
.
_mega_shard_size
],
dtype
=
torch
.
float16
,
device
=
'cuda'
)
self
.
_individual_flat_grads
=
[]
for
p_i
,
(
grads_info
,
p
)
in
enumerate
(
zip
(
self
.
_grads_info
,
self
.
_model_params
)):
self
.
_individual_flat_grads
.
append
(
self
.
_flat_grads
[
grads_info
[
"param_offset"
]:
grads_info
[
"param_offset"
]
+
grads_info
[
"param_grads_size"
]].
view_as
(
p
))
def
_flat_split
(
p
):
def
__blockify
(
p
):
return
[
p
[
block_id
*
self
.
_block_size
:(
block_id
+
1
)
*
self
.
_block_size
]
for
block_id
in
range
(
self
.
_num_blocks
)]
def
__chunkify
(
p
):
return
[
p
[
chunk_id
*
self
.
_chunk_size
:(
chunk_id
+
1
)
*
self
.
_chunk_size
]
for
chunk_id
in
range
(
self
.
_num_chunks
)]
def
__shardify
(
p
):
return
[
p
[
shard_id
*
self
.
_shard_size
:(
shard_id
+
1
)
*
self
.
_shard_size
]
for
shard_id
in
range
(
self
.
_group_size
)]
list_of_blocks
=
__blockify
(
self
.
_flat_grads
)
list_of_list_of_chunks
=
[
__chunkify
(
block
)
for
block
in
list_of_blocks
]
list_of_list_of_list_of_shards
=
[[
__shardify
(
chunk
)
for
chunk
in
chunks
]
for
chunks
in
list_of_list_of_chunks
]
return
list_of_blocks
,
list_of_list_of_chunks
,
list_of_list_of_list_of_shards
self
.
_flat_grads_blocks
,
self
.
_flat_grads_chunks
,
self
.
_flat_grads_shards
=
_flat_split
(
self
.
_flat_grads
)
def
_full_packed_split
(
p
):
def
__shardify
(
p
):
return
[
p
[
mega_shard
*
self
.
_mega_shard_size
:(
mega_shard
+
1
)
*
self
.
_mega_shard_size
]
for
mega_shard
in
range
(
self
.
_group_size
)]
def
__blockify
(
p
):
return
[
p
[
block_id
*
self
.
_num_chunks
*
self
.
_shard_size
:(
block_id
+
1
)
*
self
.
_num_chunks
*
self
.
_shard_size
]
for
block_id
in
range
(
self
.
_num_blocks
)]
def
__chunkify
(
p
):
return
[
p
[
chunk_id
*
self
.
_shard_size
:(
chunk_id
+
1
)
*
self
.
_shard_size
]
for
chunk_id
in
range
(
self
.
_num_chunks
)]
list_of_mega_shards
=
__shardify
(
p
)
list_of_list_of_mega_blocks
=
[
__blockify
(
mega_shard
)
for
mega_shard
in
list_of_mega_shards
]
list_of_list_of_list_of_mega_chunks
=
[[
__chunkify
(
mega_block
)
for
mega_block
in
mega_blocks
]
for
mega_blocks
in
list_of_list_of_mega_blocks
]
return
list_of_mega_shards
,
list_of_list_of_mega_blocks
,
list_of_list_of_list_of_mega_chunks
self
.
_new_params_mega_shards
,
self
.
_new_params_mega_blocks
,
self
.
_new_params_mega_chunks
=
_full_packed_split
(
self
.
_new_params
)
def
_packed_split
(
p
):
def
__packed_blockify
(
p
):
packed_block_size
=
self
.
_num_chunks
*
self
.
_shard_size
return
[
p
[
block_id
*
packed_block_size
:(
block_id
+
1
)
*
packed_block_size
]
for
block_id
in
range
(
self
.
_num_blocks
)]
def
__packed_chunkify
(
p
):
# in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size
return
[
p
[
chunk_id
*
self
.
_shard_size
:(
chunk_id
+
1
)
*
self
.
_shard_size
]
for
chunk_id
in
range
(
self
.
_num_chunks
)]
list_of_blocks
=
__packed_blockify
(
p
)
list_of_list_of_chunks
=
[
__packed_chunkify
(
block
)
for
block
in
list_of_blocks
]
return
list_of_blocks
,
list_of_list_of_chunks
self
.
_fp32_p_blocks
,
self
.
_fp32_p_chunks
=
_packed_split
(
self
.
_fp32_p
)
self
.
_fp32_m_blocks
,
self
.
_fp32_m_chunks
=
_packed_split
(
self
.
_fp32_m
)
self
.
_fp32_v_blocks
,
self
.
_fp32_v_chunks
=
_packed_split
(
self
.
_fp32_v
)
self
.
_fp16_p_blocks
,
self
.
_fp16_p_chunks
=
_packed_split
(
self
.
_fp16_p
)
self
.
_fp16_g_blocks
,
self
.
_fp16_g_chunks
=
_packed_split
(
self
.
_fp16_g
)
# This paragraph does two things:
# 1) Copy model parameters into master buffer
# 2) Create tensor lists for unpacking new parameter tensor after all-gather
self
.
_packed_flat_to_model_params
=
[]
self
.
_contrib_tensor_list
=
[]
self
.
_contrib_group_properties
=
[]
self
.
_non_parallel_grads
=
[]
for
shard_id
in
range
(
self
.
_group_size
):
for
block_id
in
range
(
self
.
_num_blocks
):
for
chunk_id
in
range
(
self
.
_num_chunks
):
flat_shard_start
=
(((
block_id
*
self
.
_num_chunks
+
chunk_id
)
*
self
.
_group_size
)
+
shard_id
)
*
self
.
_shard_size
flat_shard_end
=
flat_shard_start
+
self
.
_shard_size
for
(
p
,
grads_info
,
group_props
)
in
zip
(
self
.
_model_params
,
self
.
_grads_info
,
self
.
_group_properties
):
flat_grad_start
=
grads_info
[
"param_offset"
]
flat_grad_end
=
flat_grad_start
+
grads_info
[
"param_grads_size"
]
clipped_start
=
(
lambda
a
,
b
:
a
if
a
>
b
else
b
)(
flat_grad_start
,
flat_shard_start
)
clipped_end
=
(
lambda
a
,
b
:
a
if
a
<
b
else
b
)(
flat_grad_end
,
flat_shard_end
)
if
clipped_start
<
clipped_end
:
grad_offset
=
clipped_start
-
flat_grad_start
grad_length
=
clipped_end
-
clipped_start
shard_offset
=
clipped_start
-
flat_shard_start
model_param_fragment
=
p
.
view
(
-
1
)[
grad_offset
:
grad_offset
+
grad_length
]
new_param_packed_fragment
=
self
.
_new_params_mega_chunks
[
shard_id
][
block_id
][
chunk_id
][
shard_offset
:
shard_offset
+
grad_length
]
self
.
_packed_flat_to_model_params
.
append
(
(
new_param_packed_fragment
,
model_param_fragment
)
)
if
shard_id
==
self
.
_group_rank
:
# copy model parameters into master buffer
master_param_fragment
=
self
.
_fp32_p_chunks
[
block_id
][
chunk_id
][
shard_offset
:
shard_offset
+
grad_length
]
opti_state_m_fragment
=
self
.
_fp32_m_chunks
[
block_id
][
chunk_id
][
shard_offset
:
shard_offset
+
grad_length
]
opti_state_v_fragment
=
self
.
_fp32_v_chunks
[
block_id
][
chunk_id
][
shard_offset
:
shard_offset
+
grad_length
]
opti_state_g_fragment
=
self
.
_fp16_g_chunks
[
block_id
][
chunk_id
][
shard_offset
:
shard_offset
+
grad_length
]
opti_state_p_fragment
=
self
.
_fp16_p_chunks
[
block_id
][
chunk_id
][
shard_offset
:
shard_offset
+
grad_length
]
#print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
if
not
self
.
_resume_from_checkpoint
:
master_param_fragment
.
copy_
(
model_param_fragment
)
self
.
_contrib_group_properties
.
append
(
group_props
)
self
.
_contrib_tensor_list
.
append
((
master_param_fragment
,
opti_state_m_fragment
,
opti_state_v_fragment
,
opti_state_g_fragment
,
opti_state_p_fragment
))
# p, m, v, g, p_copy
if
self
.
_model_parallel
and
hasattr
(
p
,
'model_parallel'
)
and
not
p
.
model_parallel
:
self
.
_non_parallel_grads
.
append
(
opti_state_g_fragment
)
p
,
m
,
v
,
g
,
p_copy
=
list
(
zip
(
*
self
.
_contrib_tensor_list
))
self
.
_contrib_tensor_list
=
[
p
,
m
,
v
,
g
,
p_copy
]
math_type
=
self
.
_fp32_p
.
dtype
beta1
,
beta2
,
bias_correction
,
epsilon
,
decay
=
list
(
zip
(
*
self
.
_contrib_group_properties
))
self
.
_contrib_beta1
=
torch
.
tensor
(
beta1
,
dtype
=
math_type
,
device
=
'cuda'
)
self
.
_contrib_beta2
=
torch
.
tensor
(
beta2
,
dtype
=
math_type
,
device
=
'cuda'
)
self
.
_contrib_bias_correction
=
torch
.
tensor
(
bias_correction
,
dtype
=
torch
.
int
,
device
=
'cuda'
)
self
.
_contrib_epsilon
=
torch
.
tensor
(
epsilon
,
dtype
=
math_type
,
device
=
'cuda'
)
self
.
_contrib_weight_decay
=
torch
.
tensor
(
decay
,
dtype
=
math_type
,
device
=
'cuda'
)
p_in
,
p_out
=
zip
(
*
self
.
_packed_flat_to_model_params
)
self
.
_packed_flat_to_model_params
=
[
p_in
,
p_out
]
if
self
.
_num_groups
>
1
:
self
.
_ar_pg
=
[]
for
i
in
range
(
self
.
_num_process_groups
):
# gather global ranks of all members of the current process group
ranks
=
[
i
+
k
*
self
.
_num_process_groups
for
k
in
range
(
self
.
_process_group_size
)]
for
j
in
range
(
self
.
_group_size
):
ar_idx
=
[
j
+
k
*
self
.
_group_size
for
k
in
range
(
self
.
_num_groups
)]
ar_rank
=
[
ranks
[
k
]
for
k
in
ar_idx
]
#if self._global_rank in ar_rank:
# print("group for all reduce, ranks:", ar_rank)
for
_
in
range
(
self
.
_num_ar_pg
):
grp
=
torch
.
distributed
.
new_group
(
ranks
=
ar_rank
)
if
self
.
_global_rank
in
ar_rank
:
self
.
_ar_pg
.
append
(
grp
)
self
.
_ar_st
=
[
torch
.
cuda
.
Stream
()
for
_
in
range
(
self
.
_num_ar_pg
)]
for
ar_pg
in
self
.
_ar_pg
:
torch
.
distributed
.
all_reduce
(
self
.
_overflow_buf
,
group
=
ar_pg
)
self
.
_rs_pg
,
rs_ranks
=
[],[]
for
i
in
range
(
self
.
_num_process_groups
):
ranks
=
[
i
+
k
*
self
.
_num_process_groups
for
k
in
range
(
self
.
_process_group_size
)]
for
j
in
range
(
self
.
_num_groups
):
rs_idx
=
[
j
*
self
.
_group_size
+
k
for
k
in
range
(
self
.
_group_size
)]
rs_rank
=
[
ranks
[
k
]
for
k
in
rs_idx
]
#if self._global_rank in rs_rank:
# print("group for reduce scatter, ranks:", rs_rank)
for
_
in
range
(
self
.
_num_rs_pg
):
grp
=
torch
.
distributed
.
new_group
(
ranks
=
rs_rank
)
if
self
.
_global_rank
in
rs_rank
:
self
.
_rs_pg
.
append
(
grp
)
if
self
.
_compute_L2_grad_norm
:
l2_grad_norm_pg
=
torch
.
distributed
.
new_group
(
ranks
=
rs_rank
)
if
self
.
_global_rank
in
rs_rank
:
self
.
_l2_grad_norm_pg
=
l2_grad_norm_pg
torch
.
distributed
.
all_reduce
(
self
.
_overflow_buf
,
group
=
self
.
_l2_grad_norm_pg
)
self
.
_rs_st
=
[
torch
.
cuda
.
Stream
()
for
_
in
range
(
self
.
_num_rs_pg
)]
for
rs_pg
in
self
.
_rs_pg
:
torch
.
distributed
.
all_reduce
(
self
.
_overflow_buf
,
group
=
rs_pg
)
if
self
.
_num_ag_pg
==
0
:
self
.
_ag_pg
=
self
.
_rs_pg
self
.
_ag_st
=
self
.
_rs_st
self
.
_num_ag_pg
=
self
.
_num_rs_pg
else
:
self
.
_ag_pg
=
[]
for
i
in
range
(
self
.
_num_process_groups
):
ranks
=
[
i
+
k
*
self
.
_num_process_groups
for
k
in
range
(
self
.
_process_group_size
)]
for
j
in
range
(
self
.
_num_groups
):
ag_rank
=
rs_ranks
[
j
]
#if self._global_rank in ag_rank:
# print("group for all gather, ranks:", ag_rank)
for
_
in
range
(
self
.
_num_ag_pg
):
grp
=
torch
.
distributed
.
new_group
(
ranks
=
ag_rank
)
if
self
.
_global_rank
in
ag_rank
:
self
.
_ag_pg
.
append
(
grp
)
self
.
_ag_st
=
[
torch
.
cuda
.
Stream
()
for
_
in
range
(
self
.
_num_ag_pg
)]
for
ag_pg
in
self
.
_ag_pg
:
torch
.
distributed
.
all_reduce
(
self
.
_overflow_buf
,
group
=
ag_pg
)
self
.
_l2_grad_norm_st
=
torch
.
cuda
.
Stream
()
if
self
.
_compute_L2_grad_norm
else
None
self
.
_completion_st
=
torch
.
cuda
.
Stream
()
self
.
_reductions_works
=
[
None
]
*
self
.
_num_blocks
self
.
_allgather_works
=
[
None
]
*
self
.
_num_blocks
import
inspect
assert
(
'no_copy'
in
inspect
.
getfullargspec
(
torch
.
distributed
.
reduce_scatter
).
args
),
"This version of c10d does not support no_copy option"
    def _init_everything(self):
        if not self._init_done:
            self._first_step_init()
            self._init_done = True

    def set_last_step(self, last_step):
        self._last_step = last_step
    def _get_flush_block(self):
        flush_block = []
        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
            num_grads = len(self._grads_generated)
            contiguous_idx = num_grads
            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
                contiguous_idx -= 1
            if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
                self._current_block -= 1
                start = self._current_block * self._block_size
                end = (self._current_block+1) * self._block_size
                flush_block = [start, end]
        return flush_block
def
_pipeline_block_reductions
(
self
,
block_id
):
self
.
_flatten_grad_mt
(
1.0
/
self
.
_world_size
if
self
.
_predivide
else
1.0
)
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works
=
[
None
]
*
self
.
_num_chunks
for
chunk_id
in
range
(
self
.
_num_chunks
):
glob_chunk_id
=
block_id
*
self
.
_num_chunks
+
chunk_id
rs_stream
=
self
.
_rs_st
[
glob_chunk_id
%
self
.
_num_rs_pg
]
rs_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
rs_stream
):
works
[
chunk_id
]
=
torch
.
distributed
.
reduce_scatter
(
self
.
_fp16_g_chunks
[
block_id
][
chunk_id
],
self
.
_flat_grads_shards
[
block_id
][
chunk_id
],
group
=
self
.
_rs_pg
[
glob_chunk_id
%
self
.
_num_rs_pg
],
async_op
=
True
,
no_copy
=
True
)
# Reduction across nodes for each rank
if
self
.
_num_groups
>
1
:
for
chunk_id
in
range
(
self
.
_num_chunks
):
glob_chunk_id
=
block_id
*
self
.
_num_chunks
+
chunk_id
ar_stream
=
self
.
_ar_st
[
glob_chunk_id
%
self
.
_num_ar_pg
]
with
torch
.
cuda
.
stream
(
ar_stream
):
works
[
chunk_id
].
wait
()
works
[
chunk_id
]
=
torch
.
distributed
.
all_reduce
(
self
.
_fp16_g_chunks
[
block_id
][
chunk_id
],
group
=
self
.
_ar_pg
[
glob_chunk_id
%
self
.
_num_ar_pg
],
async_op
=
True
)
self
.
_reductions_works
[
block_id
]
=
works
# Optionally compute L2 grad norm
if
self
.
_compute_L2_grad_norm
and
block_id
==
0
:
with
torch
.
cuda
.
stream
(
self
.
_l2_grad_norm_st
):
for
block_id
in
range
(
self
.
_num_blocks
):
for
chunk_id
in
range
(
self
.
_num_chunks
):
self
.
_reductions_works
[
block_id
][
chunk_id
].
wait
()
# Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq
=
torch
.
empty
([
1
],
device
=
'cuda'
)
l2_grad_norm_sq
=
self
.
_fp16_g
.
norm
(
dtype
=
torch
.
float32
,
p
=
2
)
**
2
torch
.
distributed
.
all_reduce
(
l2_grad_norm_sq
,
group
=
self
.
_l2_grad_norm_pg
)
# for model_parallel_rank=0, keep all gradients
# for the rest, subtract non_parallel gradients
if
self
.
_model_parallel
and
self
.
_process_group_id
:
# non zero model_parallel_rank
non_parallel_grad_norm_sq
=
torch
.
zeros
([
1
],
device
=
'cuda'
)
if
len
(
self
.
_non_parallel_grads
):
# non parallel grads exit
non_parallel_grad_norm_sq
=
multi_tensor_applier
(
self
.
multi_tensor_l2norm
,
self
.
_overflow_buf
,
[
self
.
_non_parallel_grads
],
False
)[
0
]
**
2
torch
.
distributed
.
all_reduce
(
non_parallel_grad_norm_sq
,
group
=
self
.
_l2_grad_norm_pg
)
l2_grad_norm_sq
=
l2_grad_norm_sq
-
non_parallel_grad_norm_sq
self
.
_L2_grad_norm
=
l2_grad_norm_sq
.
sqrt
().
item
()
def
__launch_step_kernel
(
self
):
# If self._clip_grad_norm is False, we assume gradient clipping already
# happened outside the optimizer and self._global_scale has already
# been set to the combined scale, i.e. it's no longer the current loss
# scale used by the loss scaler.
# For model parallelism cases in which we need to get global gradient
# norm via all-reduce outside the optimizer to do the clipping.
combined_scale
=
self
.
_global_scale
if
self
.
_clip_grad_norm
and
self
.
_param_group
[
'max_grad_norm'
]
>
0
and
math
.
isfinite
(
self
.
L2_grad_norm
):
combined_scale
=
self
.
_param_group
[
'max_grad_norm'
]
/
(
self
.
L2_grad_norm
/
self
.
_global_scale
+
1e-6
)
combined_scale
=
self
.
_global_scale
/
min
(
1
,
combined_scale
)
self
.
_step
+=
1
multi_tensor_applier
(
distributed_adam_cuda
.
multi_tensor_fused_adam
,
self
.
_overflow_buf
,
self
.
_contrib_tensor_list
,
# p, m, v, g, p_copy
self
.
_contrib_beta1
,
self
.
_contrib_beta2
,
self
.
_contrib_bias_correction
,
self
.
_contrib_epsilon
,
self
.
_contrib_weight_decay
,
self
.
_param_group
[
'lr'
],
combined_scale
,
self
.
_step
,
self
.
eps_mode
)
def
_pipeline_step
(
self
):
# Call step kernel once per step
# Call all-gather once per step
with
torch
.
cuda
.
stream
(
self
.
_completion_st
):
for
block_id
in
range
(
self
.
_num_blocks
):
for
chunk_id
in
range
(
self
.
_num_chunks
):
self
.
_reductions_works
[
block_id
][
chunk_id
].
wait
()
self
.
__launch_step_kernel
()
torch
.
distributed
.
all_gather
(
self
.
_new_params_mega_shards
,
self
.
_fp16_p
,
group
=
self
.
_ag_pg
[
0
],
no_copy
=
True
)
    def _flatten_grad_mt(self, scale):
        if self._flat_mt and len(self._grads) > 0:
            self._overflow_buf.zero_()
            multi_tensor_applier(
                amp_C.multi_tensor_scale,
                self._overflow_buf,
                list(zip(*self._grads)),
                scale)
            self._grads = []
    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
        # handle overlapped reductions
        if self._flat_mt:
            self._grads.append((param.grad, self._individual_flat_grads[param_i]))
        else:
            torch.div(param.grad, self._world_size if self._predivide else 1.0,
                      out=self._individual_flat_grads[param_i])
        self._grads_generated[param_i] = True
        if not self._last_step:
            if self._overlap_reductions:
                flush_block = self._get_flush_block()
                while flush_block:
                    block_id = flush_block[0] // self._block_size
                    self._pipeline_block_reductions(block_id)
                    flush_block = self._get_flush_block()
def
set_global_scale
(
self
,
global_scale
):
"""Set global scale.
"""
"""
self
.
_global_scale
=
global_scale
@
property
# Copy param grad to buckets
def
global_scale
(
self
):
for
fragment
in
self
.
state
[
param
][
'fragments'
]:
return
self
.
_global_scale
# Get fragment position
bucket_id
=
fragment
[
'bucket_id'
]
bucket
=
self
.
state
[
'buckets'
][
bucket_id
]
grad_start
,
grad_end
=
fragment
[
'param_range'
]
bucket_start
,
bucket_end
=
fragment
[
'bucket_range'
]
# Set reduction status
if
bucket
[
'gradient_status'
]
==
self
.
GradientStatus
.
SYNCING
:
self
.
_finish_bucket_grad_sync
()
bucket
[
'gradient_status'
]
=
self
.
GradientStatus
.
PARTIALLY_FILLED
# Allocate gradient buffer if needed
if
bucket
[
'grads_bucket'
]
is
None
:
bucket
[
'grads_bucket'
]
=
torch
.
zeros
(
[
self
.
bucket_size
],
dtype
=
self
.
grad_sync_dtype
,
device
=
self
.
device
,
)
# Copy param grad to bucket
if
param
.
grad
is
not
None
:
fragment_in
=
param
.
grad
.
view
(
-
1
)[
grad_start
:
grad_end
]
fragment_out
=
bucket
[
'grads_bucket'
][
bucket_start
:
bucket_end
]
self
.
_grads_to_copy
.
append
((
fragment_in
,
fragment_out
))
# Free param grad buffer
if
not
self
.
fused_grad_copy
:
self
.
_finish_grad_copy
()
param
.
grad
=
None
# Update reduction statuses
self
.
_grads_generated
.
add
(
param
)
for
fragment
in
self
.
state
[
param
][
'fragments'
]:
bucket_id
=
fragment
[
'bucket_id'
]
bucket
=
self
.
state
[
'buckets'
][
bucket_id
]
is_filled
=
True
for
other_fragment
in
reversed
(
bucket
[
'fragments'
]):
param_group_id
=
other_fragment
[
'param_group_id'
]
param_id
=
other_fragment
[
'param_id'
]
other_param
=
self
.
param_groups
[
param_group_id
][
'params'
][
param_id
]
if
other_param
not
in
self
.
_grads_generated
:
is_filled
=
False
break
if
is_filled
:
bucket
[
'gradient_status'
]
=
self
.
GradientStatus
.
FULLY_FILLED
def
_finish_grad_copy
(
self
):
"""Make sure that parameter gradients have been copied to buckets
Performs any deferred copies from _start_grad_copy.
@
property
def
has_overflow
(
self
):
"""Check if overflows were detected by any call to step(...) method.
Clears the overflow flag.
"""
"""
has_overflow
=
self
.
_has_overflow
if
self
.
_grads_to_copy
:
self
.
_has_overflow
=
False
scale
=
1
/
self
.
world_size
if
self
.
average_grad_sync
else
1.0
return
has_overflow
if
self
.
fused_grad_copy
:
dummy_overflow_buf
=
torch
.
zeros
([
1
],
dtype
=
torch
.
int32
,
device
=
'cuda'
)
@
property
multi_tensor_applier
(
def
peek_overflow
(
self
):
amp_C
.
multi_tensor_scale
,
"""Check if overflows were detected by any call to step(...) method.
dummy_overflow_buf
,
Does not clear overflow flag.
list
(
zip
(
*
self
.
_grads_to_copy
)),
"""
scale
,
return
self
.
_has_overflow
)
else
:
for
fragment_in
,
fragment_out
in
self
.
_grads_to_copy
:
fragment_out
.
add_
(
fragment_in
,
alpha
=
scale
)
self
.
_grads_to_copy
=
[]
    def _force_bucket_grad_sync(self):
        """Ensure that all gradient buckets are synchronized"""

        # Synchronize all unsynchronized buckets
        self._finish_bucket_grad_sync()
        self._start_bucket_grad_sync([
            bucket
            for bucket in self.state['buckets']
            if bucket['gradient_status'] != self.GradientStatus.READY
        ])
        self._finish_bucket_grad_sync()

        # Fill any unfilled buckets with zeros
        for bucket in self.state['buckets']:
            if bucket['grads_shard'] is None:
                bucket['grads_shard'] = torch.zeros(
                    [self.shard_size],
                    dtype=self.grad_sync_dtype,
                    device=self.device,
                )

        # Reset set of generated gradients
        self._grads_generated = set()
def
_try_start_bucket_grad_sync
(
self
):
"""Launches gradient synchronization if enough buckets are ready
Gradient synchronization is asynchronous. Launches gradient
synchronization if all gradients have been generated or if
there are enough buckets ready to fill pipeline.
def
strided_check_finite
(
self
,
output_params
,
stride
=
1
,
start
=-
1
,
end
=-
1
,
clear
=
True
):
"""Strided check for overflow.
You can get status by calling has_overflow.
"""
"""
if
start
>=
0
and
start
<
end
:
if
len
(
self
.
_grads_generated
)
==
self
.
_num_grads
:
out_p
=
output_params
[
start
:
end
]
self
.
_force_bucket_grad_sync
()
else
:
out_p
=
output_params
fused_adam_cuda
.
strided_check_finite
(
self
.
_overflow_buf
,
out_p
,
stride
,
1
if
clear
else
0
)
self
.
_has_overflow
=
False
if
self
.
_overflow_buf
.
item
()
==
0
else
True
return
self
.
_has_overflow
@
property
def
L2_grad_norm
(
self
):
if
self
.
_compute_L2_grad_norm
:
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
_l2_grad_norm_st
)
return
self
.
_L2_grad_norm
else
:
else
:
return
None
filled_buckets
=
[
bucket
for
bucket
in
self
.
state
[
'buckets'
][:
-
1
]
if
bucket
[
'gradient_status'
]
==
self
.
GradientStatus
.
FULLY_FILLED
]
pipeline_size
=
(
len
(
filled_buckets
)
//
self
.
pipeline_size
)
*
self
.
pipeline_size
if
pipeline_size
>
0
:
self
.
_start_bucket_grad_sync
(
filled_buckets
[:
pipeline_size
])
def
_start_bucket_grad_sync
(
self
,
buckets
):
"""Synchronize gradients in buckets
Gradient synchronization is asynchronous. Involves
reduce-scatter over distributed process group and allreduce
over redundant process group.
def
complete_reductions
(
self
):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
"""
self
.
_init_everything
()
self
.
_finish_bucket_grad_sync
()
if
self
.
_last_step
:
self
.
_finish_grad_copy
()
# zero out gradients that have not been completed yet
for
param_i
,
grad_generated
in
enumerate
(
self
.
_grads_generated
):
# Reduce gradients
if
not
grad_generated
:
for
stream
in
self
.
_pipeline_streams
:
grad_info
=
self
.
_grads_info
[
param_i
]
stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
param_offset
=
grad_info
[
"param_offset"
]
for
i
,
bucket
in
enumerate
(
buckets
):
param_size
=
grad_info
[
"param_grads_size"
]
bucket
[
'gradient_status'
]
=
self
.
GradientStatus
.
SYNCING
self
.
_flat_grads
[
param_offset
:
param_offset
+
param_size
].
zero_
()
stream
=
self
.
_pipeline_streams
[
i
%
self
.
pipeline_size
]
self
.
_grads_generated
[
param_i
]
=
True
with
torch
.
cuda
.
stream
(
stream
):
if
self
.
_last_step
or
not
self
.
_overlap_reductions
:
# Reduce-scatter over distributed process group
# nothing done so far, run full pipeline after reductions
if
self
.
distributed_size
==
1
:
for
block_id
in
range
(
self
.
_num_blocks
-
1
,
-
1
,
-
1
):
bucket
[
'curr_grads_shard'
]
=
bucket
[
'grads_bucket'
]
self
.
_pipeline_block_reductions
(
block_id
)
bucket
[
'grad_sync_request'
]
=
None
else
:
if
self
.
_compute_L2_grad_norm
:
bucket
[
'curr_grads_shard'
]
=
torch
.
zeros
(
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
_l2_grad_norm_st
)
[
self
.
shard_size
],
dtype
=
self
.
grad_sync_dtype
,
self
.
_current_block
=
self
.
_num_blocks
device
=
self
.
device
,
self
.
_grads_generated
=
[
False
]
*
len
(
self
.
_grads_info
)
)
grads_bucket_shards
=
[
def
step
(
self
,
closure
=
None
):
bucket
[
'grads_bucket'
][
i
*
self
.
shard_size
:(
i
+
1
)
*
self
.
shard_size
]
loss
=
None
for
i
in
range
(
self
.
distributed_size
)
if
closure
is
not
None
:
]
loss
=
closure
()
if
self
.
_reduce_scatter_no_copy
:
no_copy_kwarg
=
{
'no_copy'
:
True
}
else
:
no_copy_kwarg
=
{}
bucket
[
'grad_sync_request'
]
=
(
torch
.
distributed
.
reduce_scatter
(
bucket
[
'curr_grads_shard'
],
grads_bucket_shards
,
group
=
self
.
distributed_process_group
,
async_op
=
True
,
**
no_copy_kwarg
,
)
)
# All-reduce over redundant process group
# Note: Assuming reduce-scatters are finished in the
# order they are submitted, all-reduces should be
# submitted in a consistent order. There could be race
# conditions if wait doesn't finish in order.
if
self
.
redundant_size
>
1
:
if
bucket
[
'grad_sync_request'
]
is
not
None
:
bucket
[
'grad_sync_request'
].
wait
()
bucket
[
'grad_sync_request'
]
=
(
torch
.
distributed
.
all_reduce
(
bucket
[
'curr_grads_shard'
],
group
=
self
.
redundant_process_group
,
async_op
=
True
,
)
)
    def _finish_bucket_grad_sync(self):
        """Wait for any gradient synchronizations that are in progress"""
        for bucket in self.state['buckets']:
            if bucket['gradient_status'] == self.GradientStatus.SYNCING:

                # Finish asynchronous communication
                if bucket['grad_sync_request'] is not None:
                    bucket['grad_sync_request'].wait()
                    bucket['grad_sync_request'] = None

                # Accumulate gradient in local shard
                if bucket['grads_shard'] is None:
                    bucket['grads_shard'] = bucket['curr_grads_shard']
                else:
                    bucket['grads_shard'].add_(bucket['curr_grads_shard'])

                # Deallocate buffers for gradient synchronization
                bucket['grads_bucket'] = None
                bucket['curr_grads_shard'] = None

                # Reset status
                bucket['gradient_status'] = self.GradientStatus.READY
    @contextlib.contextmanager
    def no_sync(self):
        """Disable overlapped gradient synchronization

        Context manager that is similar to
        torch.nn.parallel.DistributedDataParallel.no_sync. The
        gradients can be synchronized by calling grad_sync or step. If
        overlapped gradient synchronization is enabled, gradients can
        also be synchronized by leaving the context and performing a
        backward pass.

        """
        old_overlap_grad_sync = self.overlap_grad_sync
        self.overlap_grad_sync = False
        try:
            yield
        finally:
            self.overlap_grad_sync = old_overlap_grad_sync
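    # Illustrative usage sketch (not part of the original file): accumulating
    # gradients over several micro-batches with overlapped sync disabled, then
    # synchronizing on the final backward pass. `model`, `micro_batches` and
    # `loss_fn` are hypothetical placeholders.
    #
    #   with optimizer.no_sync():
    #       for batch in micro_batches[:-1]:
    #           loss_fn(model(batch)).backward()
    #   loss_fn(model(micro_batches[-1])).backward()
    #   optimizer.step()
    #   optimizer.zero_grad()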
    def grad_sync(self):
        """Ensure that all gradients are synchronized"""
        for bucket in self.state['buckets']:
            for fragment in bucket['fragments']:
                param_group_id = fragment['param_group_id']
                param_id = fragment['param_id']
                param = self.param_groups[param_group_id]['params'][param_id]
                if param.grad is not None:
                    self._start_grad_copy(param)
                    self._try_start_bucket_grad_sync()
        self._force_bucket_grad_sync()
def
grad_norm
(
self
):
"""Compute L2 norm of all parameter gradients
If model parallelism is enabled, exclude non-parallel
gradients on non-root processes. This is Megatron-specific, so
should this logic be moved elsewhere?
self
.
_pipeline_step
()
"""
with
torch
.
cuda
.
stream
(
self
.
_completion_st
):
# Make sure that gradients have been reduced
# Copy self._new_params to model params
self
.
grad_sync
()
multi_tensor_applier
(
fused_adam_cuda
.
maybe_cast_mt
,
# Evaluate L2 norm of distributed gradients
self
.
_overflow_buf
,
dummy_overflow_buf
=
torch
.
zeros
([
1
],
dtype
=
torch
.
int32
,
device
=
'cuda'
)
self
.
_packed_flat_to_model_params
)
grad_norm_sq
=
multi_tensor_applier
(
amp_C
.
multi_tensor_l2norm
,
dummy_overflow_buf
,
[[
bucket
[
'grads_shard'
]
for
bucket
in
self
.
state
[
'buckets'
]]],
False
,
)[
0
]
**
2
torch
.
distributed
.
all_reduce
(
grad_norm_sq
,
group
=
self
.
distributed_process_group
,
)
# If model parallelism is enabled, subtract non-parallel
# gradients on non-root processes
if
self
.
model_parallel
and
self
.
model_parallel_rank
:
non_parallel_grads
=
[]
for
bucket
in
self
.
state
[
'buckets'
]:
for
fragment
in
bucket
[
'fragments'
]:
if
fragment
[
'in_local_shard'
]:
param_group_id
=
fragment
[
'param_group_id'
]
param_id
=
fragment
[
'param_id'
]
param
=
self
.
param_groups
[
param_group_id
][
'params'
][
param_id
]
if
(
hasattr
(
param
,
'model_parallel'
)
and
not
param
.
model_parallel
):
shard_start
,
shard_end
=
fragment
[
'shard_range'
]
non_parallel_grads
.
append
(
bucket
[
'grads_shard'
][
shard_start
:
shard_end
]
)
if
non_parallel_grads
:
dummy_overflow_buf
=
torch
.
zeros
([
1
],
dtype
=
torch
.
int32
,
device
=
'cuda'
)
non_parallel_grad_norm_sq
=
multi_tensor_applier
(
amp_C
.
multi_tensor_l2norm
,
dummy_overflow_buf
,
[
non_parallel_grads
],
False
,
)[
0
]
**
2
else
:
non_parallel_grad_norm_sq
=
torch
.
zeros
([
1
],
device
=
self
.
device
)
torch
.
distributed
.
                all_reduce(
                    non_parallel_grad_norm_sq,
                    group=self.distributed_process_group,
                )
            grad_norm_sq -= non_parallel_grad_norm_sq
        return grad_norm_sq.sqrt()

    def step(self, closure=None, scale=1.):
        """Apply Adam optimizer step

        Arguments:
            closure (callable, optional): closure to recompute loss
                (default: None)
            scale (float, optional): scaling factor to divide
                gradients (default: 1.0)

        """
        self.state['step'] += 1
        loss = None
        if closure is not None:
            loss = closure()

        # Make sure that gradients have been reduced
        self.grad_sync()

        # Scale gradient if L2 norm is too large
        if self.max_grad_norm > 0:
            grad_norm = self.grad_norm().item()
            if (math.isfinite(grad_norm)
                and grad_norm / scale > self.max_grad_norm):
                scale = grad_norm / self.max_grad_norm

        # Apply optimizer step to each bucket and synchronize params
        current_stream = torch.cuda.current_stream()
        for stream in self._pipeline_streams:
            stream.wait_stream(current_stream)
        for i, bucket in enumerate(self.state['buckets']):
            stream = self._pipeline_streams[i % self.pipeline_size]
            with torch.cuda.stream(stream):

                # Buffer for param sync
                params_shard_copy = torch.zeros(
                    [self.shard_size],
                    dtype=self.param_sync_dtype,
                    device=self.device,
                )

                # Find param fragments in local shard
                buffers = collections.defaultdict(list)  # p, m, v, g, p_copy
                for fragment in bucket['fragments']:
                    if fragment['in_local_shard']:
                        param_group_id = fragment['param_group_id']
                        shard_start, shard_end = fragment['shard_range']
                        buffers[param_group_id].append([
                            bucket['params_shard'][shard_start:shard_end],
                            bucket['exp_avg_shard'][shard_start:shard_end],
                            bucket['exp_avg_sq_shard'][shard_start:shard_end],
                            bucket['grads_shard'][shard_start:shard_end],
                            params_shard_copy[shard_start:shard_end],
                        ])

                # Fuse param fragments if possible
                if len(buffers) == 1:
                    group_id = list(buffers.keys())[0]
                    buffers[group_id] = [(
                        bucket['params_shard'],
                        bucket['exp_avg_shard'],
                        bucket['exp_avg_sq_shard'],
                        bucket['grads_shard'],
                        params_shard_copy,
                    )]

                # Apply optimizer step to each param group
                for group_id, group_buffers in buffers.items():

                    # Get param group configs
                    group = self.param_groups[group_id]
                    beta1, beta2 = group['betas']
                    bias_correction = 1 if group['bias_correction'] else 0
                    eps = group['eps']
                    weight_decay = group['weight_decay']

                    # Copy param group configs to GPU
                    num_fragments = len(group_buffers)
                    beta1 = torch.full([num_fragments], beta1, dtype=self.dtype, device='cuda')
                    beta2 = torch.full([num_fragments], beta2, dtype=self.dtype, device='cuda')
                    bias_correction = torch.full([num_fragments], bias_correction, dtype=torch.int32, device='cuda')
                    eps = torch.full([num_fragments], eps, dtype=self.dtype, device='cuda')
                    weight_decay = torch.full([num_fragments], weight_decay, dtype=self.dtype, device='cuda')

                    # Apply Adam step
                    dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
                    multi_tensor_applier(
                        distributed_adam_cuda.multi_tensor_fused_adam,
                        dummy_overflow_buf,
                        list(zip(*group_buffers)),
                        beta1,
                        beta2,
                        bias_correction,
                        eps,
                        weight_decay,
                        group['lr'],
                        scale,
                        self.state['step'],
                        1,  # Set to 0 to apply eps inside sqrt
                    )

                # Deallocate buffers
                del buffers
                bucket['grads_shard'] = None

                # Allgather updated parameters
                if self.distributed_size == 1:
                    params_bucket = params_shard_copy
                else:
                    params_bucket = torch.zeros(
                        [self.bucket_size],
                        dtype=self.param_sync_dtype,
                        device=self.device,
                    )
                    params_bucket_shards = [
                        params_bucket[i*self.shard_size:(i+1)*self.shard_size]
                        for i in range(self.distributed_size)
                    ]
                    params_bucket_shards[self.distributed_rank].copy_(params_shard_copy)
                    if self._all_gather_no_copy:
                        no_copy_kwarg = {'no_copy': True}
                    else:
                        no_copy_kwarg = {}
                    torch.distributed.all_gather(
                        params_bucket_shards,
                        params_bucket_shards[self.distributed_rank],
                        group=self.distributed_process_group,
                        **no_copy_kwarg,
                    )
                del params_shard_copy

                # Copy values to param buffers
                params_in = []
                params_out = []
                for fragment in bucket['fragments']:
                    param_group_id = fragment['param_group_id']
                    param_id = fragment['param_id']
                    param = self.param_groups[param_group_id]['params'][param_id]
                    bucket_start, bucket_end = fragment['bucket_range']
                    param_start, param_end = fragment['param_range']
                    params_in.append(params_bucket[bucket_start:bucket_end])
                    params_out.append(param.view(-1)[param_start:param_end])
                if params_in:
                    dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
                    multi_tensor_applier(
                        fused_adam_cuda.maybe_cast_mt,
                        dummy_overflow_buf,
                        [params_in, params_out],
                    )
                del params_bucket, params_in, params_out

        # Synchronize pipeline streams
        for stream in self._pipeline_streams:
            current_stream.wait_stream(stream)

        return loss

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        # save step, master weights and first/second moments
        state_dict = {}
        state_dict['step'] = self._step
        state_dict['fp32_p'] = self._fp32_p
        state_dict['fp32_m'] = self._fp32_m
        state_dict['fp32_v'] = self._fp32_v
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If a DistributedFusedAdam instance was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``optimizer.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # restore step, master weights and first/second moments
        self._step = state_dict['step']
        self._fp32_p = state_dict['fp32_p'].to(device="cuda")
        self._fp32_m = state_dict['fp32_m'].to(device="cuda")
        self._fp32_v = state_dict['fp32_v'].to(device="cuda")
        self._resume_from_checkpoint = True
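A minimal usage sketch for the optimizer whose step() method is shown above. It assumes torch.distributed has already been initialized (for example by a distributed launcher) and a CUDA device has been selected; the tiny model and synthetic loss are illustrative, only the documented Adam hyperparameters are passed to the constructor, and the checkpointing lines simply mirror the state_dict()/load_state_dict() docstring examples in the listing rather than documenting the exact ZeRO-2 override.

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

# Assumption: torch.distributed.init_process_group() has already run and
# torch.cuda.set_device() has been called for this rank.
model = torch.nn.Linear(32, 32).cuda()
optimizer = DistributedFusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

for _ in range(3):
    optimizer.zero_grad()
    x = torch.randn(8, 32, device='cuda')   # synthetic data for illustration
    loss = model(x).square().mean()
    loss.backward()
    # step() calls grad_sync() internally, so reductions issued during the
    # backward pass complete before the sharded Adam update and the
    # all-gather of updated parameters.
    optimizer.step()

# Checkpointing pattern from the docstrings above (assumed, not verified here).
checkpoint = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(checkpoint, 'checkpoint.pt')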
apex/contrib/optimizers/distributed_fused_adam_v2.py
deleted 100644 → 0
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

class DistributedFusedAdamV2(torch.optim.Optimizer):

    """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in FusedAdam!
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
        use_mt (boolean, optional): use multi tensor apply for lower launch
            latency. (default: False)
        overlap_reductions(boolean, optional): whether to overlap reductions
            with bprop (default: True)
        num_prestats (integer, optional): number of fp64 stats that will be
            reduced during first fp16 gradient reduction block.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params,
                 lr=1e-3, bias_correction=True,
                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt=False,
                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
                 amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
                 compute_L2_grad_norm=False, distributed_weight_update=0,
                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
                 dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
                 dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
                 do_not_flatten_model=False):
        global fused_adam_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")

        self._amp_scale_adjustment = amp_scale_adjustment

        if use_mt:
            raise RuntimeError('DistributedFusedAdam does not support use_mt.')
        if amsgrad:
            raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')

        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(DistributedFusedAdamV2, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1

        self._overflow_buf = torch.cuda.IntTensor([0])
        self._has_overflow = False

        assert (len(self.param_groups) == 1), "More than one parameter group is not supported."

        # Way to revert a step
        # 3 -> undo kernel + double buffer (debug, print norm of difference)
        # 2 -> double buffer fp32 parameters
        # 1 -> undo kernel
        self._revert_method = revert_method
        if self._revert_method > 1:
            print("revert_method -> double buffer fp32 parameters, will consume more memory")

        self._last_step = False
        self._overlap_reductions = overlap_reductions
        self._global_scale = None
        self._num_blocks = dwu_num_blocks
        self._num_chunks = dwu_num_chunks
        self._predivide = predivide
        self._e5m2_allgather = e5m2_allgather
        self._do_not_flatten_model = do_not_flatten_model
        self._full_pipeline = full_pipeline
        self._compute_L2_grad_norm = compute_L2_grad_norm
        self._L2_grad_norm = None
        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
        self._world_size = torch.distributed.get_world_size()
        self._num_groups = self._world_size // self._group_size
        self._rank_in_group = torch.distributed.get_rank() % self._group_size

        p_offset = 0
        p_i = 0
        self._param_state = None
        self._model_params = []
        self._grads_info = []
        self._grad_accs = []
        for group in self.param_groups:
            self._param_group = group
            prev = None
            for p in group['params']:
                torch.distributed.broadcast(p, 0)
                if not p.requires_grad:
                    continue
                self._model_params.append(p)
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                if self._param_state is None:
                    self._param_state = state
                p_grads_size = p.numel()
                def wrapper(param, param_i, param_grads_size, param_offset):
                    param_tmp = param.expand_as(param)
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    def allreduce_hook(*unused):
                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
                    grad_acc.register_hook(allreduce_hook)
                    self._grad_accs.append(grad_acc)
                self._grads_info.append({"param_grads_size": p_grads_size, "param_offset": p_offset})
                wrapper(p, p_i, p_grads_size, p_offset)
                p_offset += p_grads_size
                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
                # RNN is one example of consecutive parameters:
                # (weight_ih, weight_hh, bias_ih, bias_hh)
                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
                    p_offset = ((p_offset + 63) // 64) * 64
                prev = p
                p_i += 1
        self._grads_generated = [False] * len(self._grads_info)
        self._flat_mt = flat_mt
        self._grads = []
        if self._overlap_reductions:
            self._current_block = self._num_blocks

        self._net_total_param_size = p_offset
        self._total_param_size = p_offset
        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
        self._block_size = self._total_param_size // self._num_blocks
        self._shard_size = self._block_size // self._group_size
        self._chunk_size = self._shard_size // self._num_chunks
        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d, self._chunk_size=%d" % (self._net_total_param_size, self._total_param_size, dwu_min_page_size, self._block_size, self._shard_size, self._chunk_size))

        self._low_param_i = [0] * self._num_blocks
        for block_id in range(self._num_blocks - 1, -1, -1):
            p_i = len(self._grads_info) - 1
            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id * self._block_size:
                p_i -= 1
            self._low_param_i[block_id] = p_i
        print(self._low_param_i)

        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
        self._mega_shard_size = self._num_blocks * self._num_chunks * self._chunk_size
        self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        # FIXME: Rethink fp16 label since it's either uint8 or fp16
        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')

        self._individual_flat_grads = []
        for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
            self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))

        def _flat_split(p):
            def __blockify(p):
                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
            def __shardify(p):
                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
            def __chunkify(p):
                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._group_size)]
            list_of_blocks = __blockify(self._flat_grads)
            list_of_list_of_shards = [__shardify(block) for block in list_of_blocks]
            list_of_list_of_list_of_chunks = [[__chunkify(shard) for shard in shards] for shards in list_of_list_of_shards]
            return list_of_blocks, list_of_list_of_shards, list_of_list_of_list_of_chunks
        self._flat_grads_blocks, self._flat_grads_shards, self._flat_grads_chunks = _flat_split(self._flat_grads)

        def _full_packed_split(p):
            def __shardify(p):
                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
            def __blockify(p):
                return [p[block_id*self._num_chunks*self._chunk_size:(block_id+1)*self._num_chunks*self._chunk_size] for block_id in range(self._num_blocks)]
            def __chunkify(p):
                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
            list_of_mega_shards = __shardify(p)
            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)

        def _packed_split(p):
            def __packed_blockify(p):
                packed_block_size = self._num_chunks * self._chunk_size
                return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
            def __packed_chunkify(p):
                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
            list_of_blocks = __packed_blockify(p)
            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
            return list_of_blocks, list_of_list_of_chunks
        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)

        # current arrangement
        #
        # self._flat_grads
        # self._flat_grads_blocks [x self._num_blocks, self._block_size]
        # self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
        # self._flat_grads_shards [x self._group_size, self._shard_size]
        #
        # self._new_params
        # self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._shard_size]
        # self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
        # self._new_params_mega_chunks [x self._num_chunks, self._shard_size]
        #
        # self._fp32_p
        # self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
        # self._fp32_p_chunks [x self._num_chunks, self._shard_size]
        # each chunk contains one shard
        # same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
        #
        # Usage:
        #
        # for chunk_id in range(self._num_chunks):
        #     works[chunk_id] = torch.distributed.reduce_scatter(self._flat_grads_chunks[block_id][chunk_id], self._fp16_g_chunks[block_id][chunk_id], ...)
        #
        # ----------------------------------------------------------------------------------------
        #
        # new arrangement
        #
        # NB! New equations for self._shard_size and self._chunk_size
        #
        # self._flat_grads
        # self._flat_grads_blocks [x self._num_blocks, self._block_size]
        # self._flat_grads_shards [x self._group_size, self._shard_size]
        # self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
        #
        # self._new_params
        # self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._chunk_size]
        # self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
        # self._new_params_mega_chunks [x self._num_chunks, self._chunk_size]
        #
        # self._fp32_p
        # self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
        # self._fp32_p_chunks [x self._num_chunks, self._chunk_size]
        # same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
        #
        # Usage:
        #
        # work = torch.distributed.reduce_scatter(self._flat_grads_blocks[block_id], self._fp16_g[block_id], ...)
        # for chunk_id in range(self._num_chunks):
        #     work.wait()
        #     works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], ...)
        # or
        # work.wait()
        # works[0] = torch.distributed.all_reduce(self._fp16_g_blocks[block_id], ...)

        # This paragraph does two things:
        # 1) Copy model parameters into master buffer
        # 2) Create tensor lists for unpacking new parameter tensor after all-gather
        self._packed_flat_to_model_params = []
        for shard_id in range(self._group_size):
            for block_id in range(self._num_blocks):
                flat_shard_start = (block_id * self._group_size + shard_id) * self._shard_size
                flat_shard_end = flat_shard_start + self._shard_size
                for p, grads_info in zip(self._model_params, self._grads_info):
                    flat_grad_start = grads_info["param_offset"]
                    flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
                    clipped_start = (lambda a, b: a if a > b else b)(flat_grad_start, flat_shard_start)
                    clipped_end = (lambda a, b: a if a < b else b)(flat_grad_end, flat_shard_end)
                    if clipped_start < clipped_end:
                        grad_offset = clipped_start - flat_grad_start
                        grad_length = clipped_end - clipped_start
                        shard_offset = clipped_start - flat_shard_start
                        model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
                        new_param_packed_fragment = self._new_params_mega_blocks[shard_id][block_id][shard_offset:shard_offset+grad_length]
                        self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
                        if shard_id == self._rank_in_group:
                            # copy model parameters into master buffer
                            master_param_fragment = self._fp32_p_blocks[block_id][shard_offset:shard_offset+grad_length]
                            print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
                            master_param_fragment.copy_(model_param_fragment)
        p_in, p_out = zip(*self._packed_flat_to_model_params)
        self._packed_flat_to_model_params = [p_in, p_out]

        self._distributed_weight_update = distributed_weight_update  # Is this still needed?
        self._num_rs_pg = dwu_num_rs_pg
        self._num_ar_pg = dwu_num_ar_pg
        self._num_ag_pg = dwu_num_ag_pg
        if self._num_groups > 1:
            self._ar_pg = []
            for dev_i in range(self._group_size):
                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
                for i in range(self._num_ar_pg):
                    grp = torch.distributed.new_group(ranks=ranks)
                    if torch.distributed.get_rank() in ranks:
                        self._ar_pg.append(grp)
            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
            for ar_pg in self._ar_pg:
                torch.distributed.all_reduce(self._overflow_buf, group=ar_pg)
        rs_ranks = []
        for group_i in range(self._num_groups):
            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
        self._rs_pg = []
        for group_i in range(self._num_groups):
            ranks = rs_ranks[group_i]
            for i in range(self._num_rs_pg):
                grp = torch.distributed.new_group(ranks=ranks)
                if torch.distributed.get_rank() in ranks:
                    self._rs_pg.append(grp)
            if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
                self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
                torch.distributed.all_reduce(self._overflow_buf, group=self._l2_grad_norm_pg)
        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
        for rs_pg in self._rs_pg:
            torch.distributed.all_reduce(self._overflow_buf, group=rs_pg)
        if self._num_ag_pg == 0:
            self._ag_pg = self._rs_pg
            self._ag_st = self._rs_st
            self._num_ag_pg = self._num_rs_pg
        else:
            self._ag_pg = []
            for group_i in range(self._num_groups):
                ranks = rs_ranks[group_i]
                for i in range(self._num_ag_pg):
                    grp = torch.distributed.new_group(ranks=ranks)
                    if torch.distributed.get_rank() in ranks:
                        self._ag_pg.append(grp)
            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
            for ag_pg in self._ag_pg:
                torch.distributed.all_reduce(self._overflow_buf, group=ag_pg)
        self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
        self._completion_st = torch.cuda.Stream()

        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks

        import inspect
        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"

    def set_last_step(self, last_step):
        self._last_step = last_step

    def _get_flush_block(self):
        flush_block = []
        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
            num_grads = len(self._grads_generated)
            contiguous_idx = num_grads
            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
                contiguous_idx -= 1
            if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
                self._current_block -= 1
                start = self._current_block * self._block_size
                end = (self._current_block+1) * self._block_size
                flush_block = [start, end]
        return flush_block

    def _pipeline_block_reductions(self, block_id):
        self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)

        # Reduction within each node
        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
        # The output format is the same as the fp32 master parameters
        works = [None]*self._num_chunks
        rs_stream = self._rs_st[block_id%self._num_rs_pg]
        rs_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(rs_stream):
            rs_work = torch.distributed.reduce_scatter(self._fp16_g_blocks[block_id], self._flat_grads_shards[block_id], group=self._rs_pg[block_id%self._num_rs_pg], async_op=True, no_copy=True)
            for chunk_id in range(self._num_chunks):
                works[chunk_id] = rs_work

        # Reduction across nodes for each rank
        if self._num_groups > 1:
            for chunk_id in range(self._num_chunks):
                glob_chunk_id = block_id * self._num_chunks + chunk_id
                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
                with torch.cuda.stream(ar_stream):
                    rs_work.wait()
                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], group=self._ar_pg[glob_chunk_id%self._num_ar_pg], async_op=True)
        self._reductions_works[block_id] = works

        # Optionally compute L2 grad norm
        if self._compute_L2_grad_norm and block_id == 0:
            with torch.cuda.stream(self._l2_grad_norm_st):
                for block_id in range(self._num_blocks):
                    for chunk_id in range(self._num_chunks):
                        self._reductions_works[block_id][chunk_id].wait()
                # Since the packed format is contiguous after reductions, only one norm is needed
                l2_grad_norm_sq = torch.empty([1], device='cuda')
                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()

    def __launch_step_kernel(self, p, p_copy, m, v, g):
        combined_scale = self._global_scale
        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
            combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.reversible_adam(
            p, p_copy, m, v, g,
            self._param_group['lr'],
            beta1,
            beta2,
            self._param_group['eps'],
            combined_scale,
            self._param_state['step']+1,
            self.eps_mode,
            bias_correction,
            self._param_group['weight_decay'])

    def _pipeline_block_step(self, block_id):
        # Call step kernel once per block
        ag_stream = self._ag_st[block_id%self._num_ag_pg]
        with torch.cuda.stream(ag_stream):
            for chunk_id in range(self._num_chunks):
                self._reductions_works[block_id][chunk_id].wait()
            self.__launch_step_kernel(
                self._fp32_p_blocks[block_id],
                self._fp16_p_blocks[block_id],
                self._fp32_m_blocks[block_id],
                self._fp32_v_blocks[block_id],
                self._fp16_g_blocks[block_id])
        # Call all-gather once per step.
        # FIXME: Determine which is faster, one all-gather per block or a single all-gather at end
        if block_id == 0:
            for other_ag_stream in self._ag_st:
                self._completion_st.wait_stream(other_ag_stream)
            with torch.cuda.stream(self._completion_st):
                torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)

    def _pipeline_step(self):
        # Call step kernel once per step
        # Call all-gather once per step
        with torch.cuda.stream(self._completion_st):
            for block_id in range(self._num_blocks):
                for chunk_id in range(self._num_chunks):
                    self._reductions_works[block_id][chunk_id].wait()
            self.__launch_step_kernel(
                self._fp32_p,
                self._fp16_p,
                self._fp32_m,
                self._fp32_v,
                self._fp16_g)
            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)

    def _flatten_grad_mt(self, scale):
        if self._flat_mt and len(self._grads) > 0:
            self._overflow_buf.zero_()
            multi_tensor_applier(
                amp_C.multi_tensor_scale,
                self._overflow_buf,
                list(zip(*self._grads)),
                scale)
            self._grads = []

    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
        # handle overlapped reductions
        if self._flat_mt:
            self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
        else:
            torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
        self._grads_generated[param_i] = True
        if not self._last_step:
            if self._overlap_reductions:
                flush_block = self._get_flush_block()
                while flush_block:
                    block_id = flush_block[0] // self._block_size
                    self._pipeline_block_reductions(block_id)
                    if self._full_pipeline:
                        self._pipeline_block_step(block_id)
                    flush_block = self._get_flush_block()

    def set_global_scale(self, global_scale):
        """Set global scale.
        """
        self._global_scale = global_scale

    @property
    def global_scale(self):
        return self._global_scale

    @property
    def has_overflow(self):
        """Check if overflows were detected by any call to step(...) method.
        Clears the overflow flag.
        """
        has_overflow = self._has_overflow
        self._has_overflow = False
        return has_overflow

    @property
    def peek_overflow(self):
        """Check if overflows were detected by any call to step(...) method.
        Does not clear overflow flag.
        """
        return self._has_overflow

    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
        """Strided check for overflow.
        You can get status by calling has_overflow.
        """
        if start >= 0 and start < end:
            out_p = output_params[start:end]
        else:
            out_p = output_params
        fused_adam_cuda.strided_check_finite(self._overflow_buf,
                out_p,
                stride,
                1 if clear else 0)
        self._has_overflow = False if self._overflow_buf.item() == 0 else True
        return self._has_overflow

    @property
    def L2_grad_norm(self):
        if self._compute_L2_grad_norm:
            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
            return self._L2_grad_norm
        else:
            return None

    def complete_reductions(self):
        """Complete reductions if full pipeline is not selected or overlap is not allowed.
        """
        if self._last_step:
            # zero out gradients that have not been completed yet
            for param_i, grad_generated in enumerate(self._grads_generated):
                if not grad_generated:
                    grad_info = self._grads_info[param_i]
                    param_offset = grad_info["param_offset"]
                    param_size = grad_info["param_grads_size"]
                    self._flat_grads[param_offset:param_offset+param_size].zero_()
                    self._grads_generated[param_i] = True

        if self._last_step or not self._overlap_reductions:
            # nothing done so far, run full pipeline after reductions
            for block_id in range(self._num_blocks-1, -1, -1):
                self._pipeline_block_reductions(block_id)

        if self._compute_L2_grad_norm:
            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

        self._current_block = self._num_blocks
        self._grads_generated = [False]*len(self._grads_info)

    def revert_step(self):
        """Revert effect of previously calling partial_step.
        """
        # Call undo kernel once per step
        combined_scale = self._global_scale
        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
            combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.maybe_adam_undo(
                torch.empty([0]),
                self._fp32_p,
                self._fp32_m,
                self._fp32_v,
                self._fp16_g,
                self._param_group['lr'],
                beta1,
                beta2,
                self._param_group['eps'],
                combined_scale,
                self._param_state['step']+1,
                self.eps_mode,
                bias_correction,
                self._param_group['weight_decay'])

    def step(self, closure=None, skip_overflow_check=False):
        loss = None
        if closure is not None:
            loss = closure()

        if self._last_step or not self._overlap_reductions or not self._full_pipeline:
            self._pipeline_step()

        with torch.cuda.stream(self._completion_st):
            # Check for overflow
            # Store state for loss scaler calculation
            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
            if has_overflow:
                self.revert_step()
            else:
                # Copy self._new_params to model params
                for p in self._model_params: self.state[p]['step'] += 1
                multi_tensor_applier(
                        fused_adam_cuda.maybe_cast_mt,
                        self._overflow_buf,
                        self._packed_flat_to_model_params)

        torch.cuda.current_stream().wait_stream(self._completion_st)

        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks

        return loss
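Both deleted variants derive the divisor passed to the fused kernel from the loss scale and the L2 gradient norm inside __launch_step_kernel (and again in revert_step). The following standalone sketch mirrors that arithmetic in plain Python with made-up numbers; it is an illustration of the logic shown above, not part of the Apex API:

import math

def combined_scale(global_scale, l2_grad_norm, max_grad_norm, eps=1e-6):
    """Mirror of the clipping logic in __launch_step_kernel.

    global_scale is the loss scale applied to the fp16 gradients and
    l2_grad_norm is the norm of those scaled gradients, so
    l2_grad_norm / global_scale approximates the unscaled gradient norm.
    """
    scale = global_scale
    if max_grad_norm > 0 and math.isfinite(l2_grad_norm):
        # clip_coeff < 1 only when the unscaled norm exceeds max_grad_norm
        clip_coeff = max_grad_norm / (l2_grad_norm / global_scale + eps)
        scale = global_scale / min(1, clip_coeff)
    return scale  # the kernel divides gradients by this value

# Example (hypothetical numbers): loss scale 128, scaled grad norm 2560
# (unscaled norm ~20), clip threshold 5.0 -> the divisor grows by ~4x,
# so the applied update corresponds to a gradient norm of ~5.
print(combined_scale(128.0, 2560.0, 5.0))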
apex/contrib/optimizers/distributed_fused_adam_v3.py
deleted 100644 → 0
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

class DistributedFusedAdamV3(torch.optim.Optimizer):

    """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in FusedAdam!
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
        use_mt (boolean, optional): use multi tensor apply for lower launch
            latency. (default: False)
        overlap_reductions(boolean, optional): whether to overlap reductions
            with bprop (default: True)
        num_prestats (integer, optional): number of fp64 stats that will be
            reduced during first fp16 gradient reduction block.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params,
                 lr=1e-3, bias_correction=True,
                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt=False,
                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
                 amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
                 compute_L2_grad_norm=False, distributed_weight_update=0,
                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
                 dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
                 dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
                 do_not_flatten_model=False):
        global fused_adam_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")

        self._amp_scale_adjustment = amp_scale_adjustment

        if use_mt:
            raise RuntimeError('DistributedFusedAdam does not support use_mt.')
        if amsgrad:
            raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')

        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(DistributedFusedAdamV3, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1

        self._overflow_buf = torch.cuda.IntTensor([0])

        assert (len(self.param_groups) == 1), "More than one parameter group is not supported."

        # Way to revert a step
        # 3 -> undo kernel + double buffer (debug, print norm of difference)
        # 2 -> double buffer fp32 parameters
        # 1 -> undo kernel
        self._revert_method = revert_method
        if self._revert_method > 1:
            print("revert_method -> double buffer fp32 parameters, will consume more memory")

        self._last_step = False
        self._overlap_reductions = overlap_reductions
        self._global_scale = None
        self._num_blocks = dwu_num_blocks
        self._predivide = predivide
        self._e5m2_allgather = e5m2_allgather
        self._do_not_flatten_model = do_not_flatten_model
        self._full_pipeline = full_pipeline
        self._L2_grad_norm = None
        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
        self._world_size = torch.distributed.get_world_size()
        self._num_groups = self._world_size // self._group_size
        self._rank_in_group = torch.distributed.get_rank() % self._group_size

        p_offset = 0
        p_i = 0
        self._param_state = None
        self._model_params = []
        self._grads_info = []
        self._grad_accs = []
        for group in self.param_groups:
            self._param_group = group
            prev = None
            for p in group['params']:
                torch.distributed.broadcast(p, 0)
                if not p.requires_grad:
                    continue
                self._model_params.append(p)
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                if self._param_state is None:
                    self._param_state = state
                p_grads_size = p.numel()
                def wrapper(param, param_i, param_grads_size, param_offset):
                    param_tmp = param.expand_as(param)
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    def allreduce_hook(*unused):
                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
                    grad_acc.register_hook(allreduce_hook)
                    self._grad_accs.append(grad_acc)
                self._grads_info.append({"param_grads_size": p_grads_size, "param_offset": p_offset})
                wrapper(p, p_i, p_grads_size, p_offset)
                p_offset += p_grads_size
                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
                # RNN is one example of consecutive parameters:
                # (weight_ih, weight_hh, bias_ih, bias_hh)
                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
                    p_offset = ((p_offset + 63) // 64) * 64
                prev = p
                p_i += 1
        self._grads_generated = [False]*len(self._grads_info)
        self._flat_mt = flat_mt
        self._grads = []
        self._current_block = self._num_blocks

        self._net_total_param_size = p_offset
        self._total_param_size = p_offset
        dwu_min_page_size = 256 * self._num_blocks * self._group_size
        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
        self._block_size = self._total_param_size // self._num_blocks
        self._shard_size = self._total_param_size // self._group_size
        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size, dwu_min_page_size, self._block_size, self._shard_size))

        self._low_param_i = [0]*self._num_blocks
        for block_id in range(self._num_blocks-1, -1, -1):
            p_i = len(self._grads_info)-1
            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
                p_i -= 1
            self._low_param_i[block_id] = p_i
        print(self._low_param_i)

        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
        self._flat_params = torch.zeros_like(self._flat_grads)

        def _flat_split(flat):
            def __flat_blockify(flat):
                return [flat[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
            def __flat_shardify(flat):
                return [flat[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
            return __flat_blockify(flat), __flat_shardify(flat)
        self._flat_grads_blocks, self._flat_grads_shards = _flat_split(self._flat_grads)
        self._flat_params_blocks, self._flat_params_shards = _flat_split(self._flat_params)

        # master params
        self._fp32_p = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
        self._fp32_m = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
        self._fp32_v = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')

        # copy model params to flat_params and set_ model params to flat_params.
        self._individual_flat_grads = []
        with torch.no_grad():
            for p, grads_info in zip(self._model_params, self._grads_info):
                start = grads_info["param_offset"]
                end = start + grads_info["param_grads_size"]
                flat_p = self._flat_params[start:end].view_as(p)
                flat_p.copy_(p)
                p.set_(flat_p)
                flat_grad = self._flat_grads[start:end]
                self._individual_flat_grads.append(flat_grad)
            self._fp32_p.copy_(self._flat_params_shards[self._rank_in_group].float())

        self._dwu_st = torch.cuda.Stream()
        self._l2_grad_norm_st = torch.cuda.Stream()
        for group_i in range(self._num_groups):
            ranks = [group_i*self._group_size+local_rank for local_rank in range(self._group_size)]
            pg = torch.distributed.new_group(ranks=ranks)
            if torch.distributed.get_rank() in ranks:
                self._ag_pg = pg
        torch.distributed.all_reduce(self._overflow_buf, group=self._ag_pg)

        import inspect
        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"

    @property
    def has_overflow(self):
        return True if not self.L2_grad_norm is None and not math.isfinite(self.L2_grad_norm) else False

    def set_last_step(self, last_step):
        self._last_step = last_step

    def _get_flush_block(self):
        flush_block = []
        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
            num_grads = len(self._grads_generated)
            contiguous_idx = num_grads
            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
                contiguous_idx -= 1
            if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
                self._current_block -= 1
                start = self._current_block * self._block_size
                end = (self._current_block+1) * self._block_size
                flush_block = [start, end]
        return flush_block

    def __launch_step_kernel(self, p, p_copy, m, v, g):
        combined_scale = self._global_scale
        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
            combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.reversible_adam(
            p, p_copy, m, v, g,
            self._param_group['lr'],
            beta1,
            beta2,
            self._param_group['eps'],
            combined_scale,
            self._param_state['step']+1,
            self.eps_mode,
            bias_correction,
            self._param_group['weight_decay'])

    def _flatten_grad_mt(self, scale):
        if self._flat_mt and len(self._grads) > 0:
            self._overflow_buf.zero_()
            multi_tensor_applier(
                amp_C.multi_tensor_scale,
                self._overflow_buf,
                list(zip(*self._grads)),
                scale)
            self._grads = []

    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
        # handle overlapped reductions
        if self._flat_mt:
            self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
        else:
            torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
        self._grads_generated[param_i] = True
        if not self._last_step and self._overlap_reductions:
            flush_block = self._get_flush_block()
            while flush_block:
                block_id = flush_block[0] // self._block_size
                self._dwu_st.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self._dwu_st):
                    self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
                    torch.distributed.all_reduce(self._flat_grads_blocks[block_id])
                if block_id == 0:
                    self._l2_grad_norm_st.wait_stream(self._dwu_st)
                    with torch.cuda.stream(self._l2_grad_norm_st):
                        self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
                flush_block = self._get_flush_block()

    def set_global_scale(self, global_scale):
        """Set global scale.
        """
        self._global_scale = global_scale

    @property
    def global_scale(self):
        return self._global_scale

    @property
    def L2_grad_norm(self):
        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
        return self._L2_grad_norm

    def complete_reductions(self):
        """Complete reductions if full pipeline is not selected or overlap is not allowed.
        """
        if self._last_step:
            # zero out gradients that have not been completed yet
            for param_i, flat_grad in enumerate(self._individual_flat_grads):
                if not self._grads_generated[param_i]:
                    flat_grad.zero_()
                    self._grads_generated[param_i] = True

        if self._last_step or not self._overlap_reductions:
            # nothing done so far, run full pipeline after reductions
            self._dwu_st.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self._dwu_st):
                self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
                torch.distributed.all_reduce(self._flat_grads)
            self._l2_grad_norm_st.wait_stream(self._dwu_st)
            with torch.cuda.stream(self._l2_grad_norm_st):
                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()

        self._current_block = self._num_blocks
        self._grads_generated = [False]*len(self._grads_info)

    def step(self, closure=None, skip_overflow_check=False):
        loss = None
        if closure is not None:
            loss = closure()

        with torch.cuda.stream(self._dwu_st):
            self.__launch_step_kernel(
                self._fp32_p,
                self._flat_params_shards[self._rank_in_group],
                self._fp32_m,
                self._fp32_v,
                self._flat_grads_shards[self._rank_in_group])
            torch.distributed.all_gather(self._flat_params_shards, self._flat_params_shards[self._rank_in_group], group=self._ag_pg, no_copy=True)
            for p in self._model_params: self.state[p]['step'] += 1
        torch.cuda.current_stream().wait_stream(self._dwu_st)

        return loss
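The deleted classes view a single padded flat buffer in two ways: as num_blocks equal blocks, which are the units of overlapped reduction during the backward pass, and as group_size equal shards, one per rank, which hold the fp32 master state. A small self-contained sketch of the _flat_split partitioning from DistributedFusedAdamV3, with made-up sizes:

import torch

# Illustration only: sizes below are hypothetical, not taken from the diff.
total_param_size = 48
num_blocks = 4
group_size = 2
block_size = total_param_size // num_blocks   # 12 elements per block
shard_size = total_param_size // group_size   # 24 elements per rank

flat = torch.arange(total_param_size, dtype=torch.float16)
blocks = [flat[i * block_size:(i + 1) * block_size] for i in range(num_blocks)]
shards = [flat[i * shard_size:(i + 1) * shard_size] for i in range(group_size)]

# Blocks are all-reduced as soon as their gradients are ready; each rank then
# runs the Adam kernel only on its own shard before the final all-gather.
assert sum(b.numel() for b in blocks) == flat.numel()
assert sum(s.numel() for s in shards) == flat.numel()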
tests/L0/run_optimizers/test_dist_adam.py
Previous version (removed):

import argparse
import random
import sys

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from apex import amp
from apex.optimizers import FusedAdam
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(*[
            torch.nn.Linear(args.dim, args.dim) for _ in range(args.layers)
        ])

    def forward(self, x):
        return self.linear(x)


def setup(args):
    ## Model
    ref_model = TestModel(args).cuda()
    dist_model = TestModel(args).cuda()
    # Same weights
    with torch.no_grad():
        for dp, rp in zip(dist_model.parameters(), ref_model.parameters()):
            dp.data.copy_(rp.data)

    dist_model = dist_model.half()

    ## Optimizer
    # same hyperparameters
    ref_opt_args = {'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01}
    ref_opt = FusedAdam(ref_model.parameters(), **ref_opt_args)

    dist_opt_args = ref_opt_args.copy()
    dist_opt_args.update({'overlap_reductions': False})
    dist_opt_args.update({'process_group_size': args.n_gpu})
    dist_opt_args.update({'dwu_group_size': args.dwu_group_size})
    dist_opt_args.update({'dwu_num_blocks': 1})
    dist_opt_args.update({'dwu_num_chunks': 1})
    dist_opt = DistributedFusedAdam(dist_model.parameters(), **dist_opt_args)
    dist_opt.set_global_scale(1.)

    ## amp-init
    amp_args = {'loss_scale': 'dynamic', 'opt_level': 'O2'}
    ref_model, ref_opt = amp.initialize(ref_model, ref_opt, **amp_args)

    ## DDP
    ref_model = DDP(ref_model, device_ids=[args.rank])
    with torch.no_grad():
        for dp in dist_model.parameters():
            torch.distributed.broadcast(dp.data, src=0)
        for rp in ref_model.parameters():
            torch.distributed.broadcast(rp.data, src=0)
    torch.cuda.synchronize()
    torch.distributed.barrier()
    if get_rank() == 0:
        print(f'dist opt with {args.n_gpu} GPUs')

    return ref_model, ref_opt, dist_model, dist_opt


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--dim', type=int, default=4)
    parser.add_argument('--layers', type=int, default=2)
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1)
    parser.add_argument('--dwu_group_size', type=float, default=1)
    args = parser.parse_args()
    return args


def setup_env(args):
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.n_gpu = torch.distributed.get_world_size()
    seed = 42 + get_rank()

    random.seed(seed)
    torch.manual_seed(seed)

    return args


def get_rank():
    return torch.distributed.get_rank()


def main():
    args = parse_args()
    args = setup_env(args)
    tol_args = {'atol': args.atol, 'rtol': args.rtol}
    torch.set_printoptions(precision=16)

    ref_model, ref_opt, dist_model, dist_opt = setup(args)

    # lazy_init not called yet, initialize stash
    stash = ref_opt._amp_stash
    stash.all_fp16_params, stash.all_fp32_from_fp16_params = [], []
    # make sure everything from _first_step_init_ is ready before training
    # e.g. registering allreduce_hook
    # so that gradients are copied/reduced when necessary
    dist_opt._init_everything()

    for i in range(args.steps):
        x_ref = torch.randn(args.batch, args.dim, dtype=torch.half).cuda().requires_grad_(True)
        x_dist = x_ref.clone().detach().requires_grad_(True)
        if get_rank() == 0:
            print(f'[{i}] Checking input')
        #print("x_ref:", x_ref.flatten()[:10])
        #print("x_dist:", x_dist.flatten()[:10])
        assert(torch.allclose(x_ref, x_dist, **tol_args))
        y_ref = ref_model(x_ref).half()
        y_dist = dist_model(x_dist)
        if get_rank() == 0:
            print(f'[{i}] Checking output')
        #print("y_ref:", y_ref.flatten()[:10])
        #print("y_dist:", y_dist.flatten()[:10])
        assert(torch.allclose(y_ref, y_dist, **tol_args))
        dy = torch.randn_like(y_ref)
        y_ref.backward(dy)
        y_dist.backward(dy)
        if get_rank() == 0:
            print(f'[{i}] Checking gradients')
        assert(torch.allclose(x_ref.grad, x_dist.grad, **tol_args))
        # gradient all-reduce within distributed optimizer
        dist_opt.complete_reductions()
        if get_rank() == 0:
            print(f'[{i}] Stepping')
        ref_opt.step()
        dist_opt.step()
        torch.cuda.synchronize()
        torch.distributed.barrier()
        print('Checking new weights')
        if get_rank() == 0:
            print("ref param:", ref_model.module.linear[0].weight)
            print("dist param:", dist_model.linear[0].weight)
        for i, (rp, dp) in enumerate(zip(ref_model.parameters(), dist_model.parameters())):
            if not torch.allclose(rp, dp, **tol_args):
                if get_rank() == 0:
                    print(f'Rank: {get_rank()}, Param: {i}')
                    print(f'ref: {rp.sum().item()}, dist: {dp.sum().item()}')
                    print(rp)
                    print(dp)
                    print(torch.abs(rp - dp) > tol_args['atol'])
                sys.exit(0)
        # zero grads
        for rp, dp in zip(ref_model.parameters(), dist_model.parameters()):
            rp.grad = None
            dp.grad = None


if __name__ == "__main__":
    main()

Updated version:

import argparse
import os
import random

import torch

from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(*[
            torch.nn.Linear(args.dim, args.dim, bias=args.bias)
            for _ in range(args.layers)
        ])

    def forward(self, x):
        y = 0
        for i, l in enumerate(self.linear):
            y += (i+1) * l(x)
        return y


def setup(args):

    # Construct models with same parameters
    ref_model = TestModel(args).float().cuda()
    dist_model = TestModel(args).float().cuda()
    with torch.no_grad():
        for ref_param, dist_param in zip(dist_model.parameters(),
                                         ref_model.parameters()):
            dist_param.data.copy_(ref_param.data)
    ref_model = torch.nn.parallel.DistributedDataParallel(
        ref_model,
        device_ids=[args.rank],
        output_device=args.rank,
    )

    # Construct optimizers with same hyperparameters
    optim_args = {'lr': 1, 'betas': (0.5, 0.75), 'eps': 0.1, 'weight_decay': 0.1}
    ref_optim = torch.optim.AdamW(
        [
            {'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
            {'params': list(ref_model.parameters())[0::2]},
        ],
        **optim_args,
    )
    dist_optim = DistributedFusedAdam(
        [
            {'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
            {'params': list(dist_model.parameters())[0::2]},
        ],
        bucket_cap_mb=71/(4*1024*1024),
        **optim_args,
    )

    return ref_model, ref_optim, dist_model, dist_optim


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=3)
    parser.add_argument('--batch', type=int, default=5)
    parser.add_argument('--dim', type=int, default=7)
    parser.add_argument('--layers', type=int, default=11)
    parser.add_argument('--bias', action='store_true')
    parser.add_argument('--atol', type=float, default=1e-5)
    parser.add_argument('--rtol', type=float, default=1e-5)
    args = parser.parse_args()
    return args


def setup_env(args):

    # Initialize NCCL
    local_rank = args.local_rank
    if local_rank < 0:
        local_rank = int(os.getenv('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank % torch.cuda.device_count())
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.world_size = torch.distributed.get_world_size()

    # Initialize RNG
    seed = 42 + args.rank
    random.seed(seed)
    torch.manual_seed(seed)

    return args


def main():
    args = parse_args()
    args = setup_env(args)
    torch.set_printoptions(precision=16)

    def assert_allclose(ref_x, dist_x, message):
        message = (
            f'Rank {args.rank}: {message}\n'
            f'Reference Adam: {ref_x}\n'
            f'Distributed Adam: {dist_x}\n'
            f'Relative error: {torch.abs((ref_x-dist_x)/ref_x)}\n'
        )
        assert torch.allclose(ref_x, dist_x, atol=args.atol, rtol=args.rtol), message

    # Train model with data-parallelism and ZeRO
    ref_model, ref_optim, dist_model, dist_optim = setup(args)
    for step in range(args.steps):

        # Synthetic data
        x = torch.randn(args.batch, args.dim).cuda()
        dy = torch.randn_like(x).cuda()

        # Reference implementation
        ref_optim.zero_grad()
        x_ref = x.detach().clone().requires_grad_(True)
        y_ref = ref_model(x_ref)
        y_ref.backward(dy)
        ref_optim.step()

        # Distributed implementation
        dist_optim.zero_grad()
        x_dist = x.detach().clone().requires_grad_(True)
        y_dist = dist_model(x_dist)
        y_dist.backward(dy)
        dist_optim.step()

        # Check values
        torch.distributed.barrier()
        torch.cuda.synchronize()
        assert_allclose(
            y_ref,
            y_dist,
            f'inconsistent output in step {step}',
        )
        assert_allclose(
            x_ref.grad,
            x_dist.grad,
            f'inconsistent input grad in step {step}',
        )
        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
                                                        dist_model.parameters())):
            assert_allclose(
                ref_param,
                dist_param,
                f'inconsistent param {i} in step {step}',
            )


if __name__ == "__main__":
    main()
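The updated test is a plain script, so it is presumably launched with a distributed launcher such as torchrun --nproc_per_node=2 tests/L0/run_optimizers/test_dist_adam.py (the launcher and GPU count are assumptions, not stated in the diff). The bucket_cap_mb value passed in setup() is deliberately tiny: assuming the optimizer converts it to bytes by multiplying with 1024*1024, it works out to only a few elements per bucket, which spreads the parameters over many buckets and exercises the bucketing and sharding logic. A short worked-arithmetic sketch of that assumption:

# Worked arithmetic for the bucket_cap_mb constant used in setup() above.
# Assumption: bucket capacity in bytes = bucket_cap_mb * 1024 * 1024.
bucket_cap_mb = 71 / (4 * 1024 * 1024)
bucket_cap_bytes = bucket_cap_mb * 1024 * 1024   # = 17.75 bytes
elements_fp32 = bucket_cap_bytes / 4             # ~4 fp32 elements per bucket
elements_fp16 = bucket_cap_bytes / 2             # ~8 fp16 elements per bucket
print(bucket_cap_bytes, elements_fp32, elements_fp16)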