OpenDAS / apex

Commit 44f54712, authored Apr 28, 2020 by Thor Johnsen
Parent: f0448054

Reduce CPU overhead, bigger step, all-gather

Showing 1 changed file with 195 additions and 241 deletions.

apex/contrib/optimizers/distributed_fused_adam.py (+195, -241)
@@ -68,6 +68,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        self._overflow_buf = torch.cuda.IntTensor([0])
        assert (len(self.param_groups) == 1), "More than one parameter group is not supported."

        # Way to revert a step
        # 3 -> undo kernel + double buffer (debug, print norm of difference)
        # 2 -> double buffer fp32 parameters
@@ -94,12 +96,21 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        p_offset = 0
        p_i = 0
        self._param_state = None
        self._model_params = []
        self._grads_info = []
        for group in self.param_groups:
            self._param_group = group
            for p in group['params']:
                torch.distributed.broadcast(p, 0)
                if not p.requires_grad:
                    continue
                self._model_params.append(p)
                state = self.state['p']
                if len(state) == 0:
                    state['step'] = 0
                if self._param_state is None:
                    self._param_state = state
                p_grads_size = p.numel()
                def wrapper(param, param_i, param_grads_size, param_offset):
                    def allreduce_hook(grad):
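The wrapper / allreduce_hook nesting registered here (its body is collapsed in this hunk) is the usual trick for binding the per-parameter values param_i, param_grads_size and param_offset at registration time instead of sharing the loop variables. A minimal standalone sketch of that closure pattern, with toy tensors, hypothetical names and no distributed calls:

import torch

params = [torch.randn(3, requires_grad=True), torch.randn(5, requires_grad=True)]
offsets = [0, 3]  # flat offsets, analogous to param_offset

def make_hook(param_i, param_offset):
    # Bind the current loop values now; a hook written inline would see only the last iteration's values.
    def hook(grad):
        print("grad for param %d at flat offset %d, norm %.3f" % (param_i, param_offset, grad.norm().item()))
    return hook

for i, (p, off) in enumerate(zip(params, offsets)):
    p.register_hook(make_hook(i, off))

loss = sum((p * p).sum() for p in params)
loss.backward()  # the hooks fire here, once per parameter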
@@ -134,12 +145,85 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                self._low_param_i[block_id] = p_i
        print(self._low_param_i)

        self._flat_grads = torch.zeros([self._total_param_size]).half().cuda()
        self._new_params = None
        self._fp32_p = None
        self._fp32_m = None
        self._fp32_v = None
        self._copy_to_fp32 = False
        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
        self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
        self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        # FIXME: Rethink fp16 label since it's either uint8 or fp16
        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
        def _flat_split(p):
            def __blockify(p):
                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
            def __chunkify(p):
                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
            def __shardify(p):
                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
            list_of_blocks = __blockify(self._flat_grads)
            list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
            list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
            return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
        self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
        def _full_packed_split(p):
            def __shardify(p):
                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
            def __blockify(p):
                return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]
            def __chunkify(p):
                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
            list_of_mega_shards = __shardify(p)
            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
        def _packed_split(p):
            def __packed_blockify(p):
                packed_block_size = self._num_chunks*self._shard_size
                return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
            def __packed_chunkify(p):
                # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size
                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
            list_of_blocks = __packed_blockify(p)
            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
            return list_of_blocks, list_of_list_of_chunks
        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
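For illustration, a minimal standalone sketch (not part of this commit) of what these split helpers produce: nested lists of slices that are views into one flat buffer, so nothing is copied and later steps can address any block/chunk/shard without re-slicing on the CPU every iteration. Sizes and names below are toy assumptions.

import torch

num_blocks, num_chunks, group_size = 2, 2, 4   # toy sizes
shard_size = 3
chunk_size = shard_size * group_size
block_size = chunk_size * num_chunks

flat = torch.arange(num_blocks * block_size, dtype=torch.float32)

blocks = [flat[b * block_size:(b + 1) * block_size] for b in range(num_blocks)]
chunks = [[blk[c * chunk_size:(c + 1) * chunk_size] for c in range(num_chunks)] for blk in blocks]
shards = [[[ch[s * shard_size:(s + 1) * shard_size] for s in range(group_size)] for ch in row] for row in chunks]

# The slices are views: writing through a shard view mutates the flat buffer.
shards[1][0][2].fill_(-1.0)
assert flat[block_size + 2 * shard_size] == -1.0
assert shards[0][0][0].data_ptr() == flat.data_ptr()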
        # This paragraph does two things:
        # 1) Copy model parameters into master buffer
        # 2) Create tensor lists for unpacking new parameter tensor after all-gather
        self._packed_flat_to_model_params = []
        for shard_id in range(self._group_size):
            for block_id in range(self._num_blocks):
                for chunk_id in range(self._num_chunks):
                    flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
                    flat_shard_end = flat_shard_start + self._shard_size
                    for p, grads_info in zip(self._model_params, self._grads_info):
                        flat_grad_start = grads_info["param_offset"]
                        flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
                        clipped_start = (lambda a, b: a if a > b else b)(flat_grad_start, flat_shard_start)
                        clipped_end = (lambda a, b: a if a < b else b)(flat_grad_end, flat_shard_end)
                        if clipped_start < clipped_end:
                            grad_offset = clipped_start - flat_grad_start
                            grad_length = clipped_end - clipped_start
                            shard_offset = clipped_start - flat_shard_start
                            model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
                            new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
                            self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
                            if shard_id == self._rank_in_group:
                                # copy model parameters into master buffer
                                master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
                                master_param_fragment.copy_(model_param_fragment)
        p_in, p_out = zip(*self._packed_flat_to_model_params)
        self._packed_flat_to_model_params = [p_in, p_out]

        self._distributed_weight_update = distributed_weight_update  # Is this still needed?
        self._num_rs_pg = dwu_num_rs_pg
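The flat gradient buffer is ordered [block][chunk][shard], while the packed buffers used for the all-gather are ordered [shard][block][chunk]; the flat_shard_start expression above locates a shard in the former, and the mega-shard views locate the same shard in the latter. A small standalone check with toy sizes (all names are local to this sketch, not part of the commit):

num_blocks, num_chunks, group_size, shard_size = 2, 2, 4, 3  # toy sizes

total = num_blocks * num_chunks * group_size * shard_size
mega_shard_size = num_blocks * num_chunks * shard_size       # one rank's contiguous slice

def flat_shard_start(block_id, chunk_id, shard_id):
    # Same expression as in the constructor: flat gradients are laid out [block][chunk][shard].
    return (((block_id * num_chunks + chunk_id) * group_size) + shard_id) * shard_size

def packed_shard_start(shard_id, block_id, chunk_id):
    # Packed buffers are laid out [shard][block][chunk], so each rank owns one contiguous mega shard.
    return shard_id * mega_shard_size + (block_id * num_chunks + chunk_id) * shard_size

# Both layouts tile the whole buffer exactly once, shard_size elements at a time.
flat_starts = {flat_shard_start(b, c, s) for b in range(num_blocks)
               for c in range(num_chunks) for s in range(group_size)}
packed_starts = {packed_shard_start(s, b, c) for b in range(num_blocks)
                 for c in range(num_chunks) for s in range(group_size)}
assert flat_starts == packed_starts == set(range(0, total, shard_size))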
@@ -221,64 +305,91 @@ class DistributedFusedAdam(torch.optim.Optimizer):
    def _pipeline_block_reductions(self, block_id):
        self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
        start = block_id * self._block_size
        end = start + self._block_size
        grad_block = self._flat_grads[start:end]

        # Reduction within each node
        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
        # The output format is the same as the fp32 master parameters
        works = [None]*self._num_chunks
        for chunk in range(self._num_chunks):
            glob_chunk = block_id * self._num_chunks + chunk
            grad_chunk = grad_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
            grad_shards = [grad_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
            rs_stream = self._rs_st[glob_chunk%self._num_rs_pg]
        for chunk_id in range(self._num_chunks):
            glob_chunk_id = block_id * self._num_chunks + chunk_id
            rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
            rs_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(rs_stream):
                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group], grad_shards, group=self._rs_pg[glob_chunk%self._num_rs_pg], async_op=True, no_copy=True)
                if self._num_groups > 1:
                    ar_stream = self._ar_st[glob_chunk%self._num_ar_pg]
                    with torch.cuda.stream(ar_stream):
                        work.wait()
                        work = torch.distributed.all_reduce(grad_shards[self._rank_in_group], group=self._ar_pg[glob_chunk%self._num_ar_pg], async_op=True)
            works[chunk] = work
        if self._compute_L2_grad_norm:
            for chunk in range(self._num_chunks):
                grad_chunk = grad_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
                grad_shards = [grad_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
                with torch.cuda.stream(self._l2_grad_norm_st):
                    works[chunk].wait()
                    l2_grad_sq = grad_shards[self._rank_in_group].norm(dtype=torch.float32, p=2)**2
                    if block_id+1 == self._num_blocks and chunk == 0:
                        self._L2_grad_norm = l2_grad_sq
                    else:
                        self._L2_grad_norm += l2_grad_sq
                    if block_id == 0 and chunk+1 == self._num_chunks:
                        torch.distributed.all_reduce(self._L2_grad_norm, group=self._l2_grad_norm_pg)
                        self._L2_grad_norm.sqrt_()
                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id], self._flat_grads_shards[block_id][chunk_id], group=self._rs_pg[glob_chunk_id%self._num_rs_pg], async_op=True, no_copy=True)
        # Reduction across nodes for each rank
        if self._num_groups > 1:
            for chunk_id in range(self._num_chunks):
                glob_chunk_id = block_id * self._num_chunks + chunk_id
                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
                with torch.cuda.stream(ar_stream):
                    works[chunk_id].wait()
                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], group=self._ar_pg[glob_chunk_id%self._num_ar_pg], async_op=True)
        self._reductions_works[block_id] = works
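A condensed sketch of the reduction pattern used above: an asynchronous reduce-scatter inside the node on a side stream, followed by an all-reduce of the local shard across nodes. It uses only stock torch.distributed (the no_copy=True argument in the diff is an NVIDIA-patched extension), and the process groups, rank and stream arguments are assumed to be set up by the caller:

import torch
import torch.distributed as dist

def reduce_gradient_chunk(grad_shards, rank_in_group, intra_node_pg, inter_node_pg, stream):
    """Reduce one chunk: reduce-scatter inside the node, then all-reduce the surviving
    local shard across nodes. The caller waits on the returned work handle before
    reading the reduced shard."""
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        work = dist.reduce_scatter(grad_shards[rank_in_group], grad_shards,
                                   group=intra_node_pg, async_op=True)
        if inter_node_pg is not None:
            work.wait()
            work = dist.all_reduce(grad_shards[rank_in_group],
                                   group=inter_node_pg, async_op=True)
    return work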
    def _pipeline_block_step(self, block_id):
        if self._new_params is None:
            self._new_params = torch.zeros_like(self._flat_grads, dtype=torch.uint8 if self._e5m2_allgather else self._flat_grads.dtype)
        start = block_id * self._block_size
        end = start + self._block_size
        new_params_block = self._new_params[start:end]
        # Optionally compute L2 grad norm
        if self._compute_L2_grad_norm and block_id == 0:
            with torch.cuda.stream(self._l2_grad_norm_st):
                for block_id in range(self._num_blocks):
                    for chunk_id in range(self._num_chunks):
                        self._reductions_works[block_id][chunk_id].wait()
                # Since the packed format is contiguous after reductions, only one norm is needed
                self._L2_grad_norm = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                torch.distributed.all_reduce(self._L2_grad_norm, group=self._l2_grad_norm_pg)
                self._L2_grad_norm.sqrt_()
    def __launch_step_kernel(self, p, p_copy, m, v, g):
        combined_scale = self._global_scale
        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
            combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.adam(p, p_copy, m, v, g, self._param_group['lr'], beta1, beta2, self._param_group['eps'], combined_scale, self._param_state['step']+1, self.eps_mode, bias_correction, self._param_group['weight_decay'])
        works = [None]*self._num_chunks
        for chunk in range(self._num_chunks):
            glob_chunk = block_id * self._num_chunks + chunk
            new_params_chunk = new_params_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
            new_params_shards = [new_params_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
            ag_stream = self._ag_st[glob_chunk%self._num_ag_pg]
            with torch.cuda.stream(ag_stream):
                self._reductions_works[block_id][chunk].wait()
                self._partial_step_single_shard(block_id, chunk)
                work = torch.distributed.all_gather(new_params_shards, new_params_shards[self._rank_in_group], group=self._ag_pg[glob_chunk%self._num_ag_pg], async_op=True, no_copy=True)
                works[chunk] = work
        self._allgather_works[block_id] = works
    def _pipeline_block_step(self, block_id):
        # Call step kernel once per block
        ag_stream = self._ag_st[block_id%self._num_ag_pg]
        with torch.cuda.stream(ag_stream):
            for chunk_id in range(self._num_chunks):
                self._reductions_works[block_id][chunk_id].wait()
            self.__launch_step_kernel(
                self._fp32_p_blocks[block_id],
                self._fp16_p_blocks[block_id],
                self._fp32_m_blocks[block_id],
                self._fp32_v_blocks[block_id],
                self._fp16_g_blocks[block_id])
        # Call all-gather once per step.
        # FIXME: Determine which is faster, one all-gather per block or a single all-gather at end
        if block_id == 0:
            for other_ag_stream in self._ag_st:
                self._completion_st.wait_stream(other_ag_stream)
            with torch.cuda.stream(self._completion_st):
                torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
    def _pipeline_step(self):
        # Call step kernel once per step
        # Call all-gather once per step
        with torch.cuda.stream(self._completion_st):
            for block_id in range(self._num_blocks):
                for chunk_id in range(self._num_chunks):
                    self._reductions_works[block_id][chunk_id].wait()
            self.__launch_step_kernel(
                self._fp32_p,
                self._fp16_p,
                self._fp32_m,
                self._fp32_v,
                self._fp16_g)
            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
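A sketch of the single per-step all-gather used above, written with stock torch.distributed (again, no_copy=True is a patched extension): every rank contributes its contiguous updated shard and receives the others directly into views of a flat output buffer. It assumes an initialized process group; all names are local to the sketch.

import torch
import torch.distributed as dist

def gather_updated_params(new_params_flat, local_fp16_shard, mega_shard_size, group):
    """One all-gather per step: the output list entries are views of the flat buffer,
    so the collective writes every rank's shard straight into place."""
    world = dist.get_world_size(group)
    out_shards = [new_params_flat[i * mega_shard_size:(i + 1) * mega_shard_size]
                  for i in range(world)]
    dist.all_gather(out_shards, local_fp16_shard, group=group)
    return out_shards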
    def _flatten_grad_mt(self, scale):
        if self._flat_mt:
@@ -360,145 +471,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        else:
            return None

    # Distributed weight update algorithm:
    # Model parameters are kept as-is.
    # Gradients are flattened during backprop.
    # Reductions are done with an intra-node reduce-scatter followed by an inter-node all-reduce.
    # Step function is sharded and the shards are assembled with an intra-node all-gather.
    # Sharded step function needs internal fp32 buffers for p, m and v.
    # To save memory, we allocate the fp32 buffers to cover only the shards local GPU will update.
    # This means we have to play around with indexes, which requires knowledge of block and shard number.
    # Implement a method that performs a partial update of a single shard within a single block.
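A back-of-the-envelope sketch of the memory saving described in the comment above, with hypothetical sizes that are not taken from this commit:

def fp32_state_bytes(total_params, group_size):
    """fp32 master weights plus Adam m and v, sharded across group_size ranks."""
    per_rank_elems = total_params // group_size   # analogous to num_blocks*num_chunks*shard_size
    return 3 * 4 * per_rank_elems                 # 3 buffers, 4 bytes per element

# 1.3B parameters on an 8-GPU node: about 14.5 GiB of fp32 state unsharded vs about 1.8 GiB per GPU.
print(fp32_state_bytes(1_300_000_000, 1) / 2**30)
print(fp32_state_bytes(1_300_000_000, 8) / 2**30)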
    def _partial_step_single_shard(self, block_id, chunk_id, undo=False):
        """Perform step function for a single shard.

        Arguments:
            block_id (integer): Block index of shard [0,self._num_blocks>
            undo (boolean, optional): If True, undo effect of previously called partial step.
        """
        shard_id = self._rank_in_group
        shard_start = (block_id * self._num_chunks + chunk_id) * self._chunk_size + shard_id * self._shard_size
        shard_end = shard_start + self._shard_size

        if self._fp32_p is None:
            assert (not undo), "Tried to undo step before calling step."
            # Allocate fp32 buffers on demand. Note that we don't make these part of the state
            # since each rank only has partial buffers.
            # To-Do:
            self._fp32_p = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
            self._fp32_m = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
            self._fp32_v = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
            if self._revert_method > 1:
                self._fp32_backup_p = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
                self._fp32_backup_m = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
                self._fp32_backup_v = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
            self._copy_to_fp32 = True

        step = None
        param_i = 0
        for group in self.param_groups:
            # compute combined scale factor for this group
            combined_scale = self._global_scale
            if group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
                combined_scale = group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
                combined_scale = self._global_scale / min(1, combined_scale)
            bias_correction = 1 if group['bias_correction'] else 0
            group_start = -1
            group_end = -2
            for p in group['params']:
                if not p.requires_grad:
                    continue
                #if p.grad.is_sparse:
                #    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                if step is None:
                    # all we want from state at this point is state['step'], which should be the same for all p
                    step = state['step']
                nels = p.numel()
                offset = self._grads_info[param_i]['param_offset']
                param_i += 1
                start = offset
                end = start + nels
                clipped_start = start if start >= shard_start else shard_start
                clipped_end = end if end <= shard_end else shard_end
                # check if this parameter contributes to shard
                if clipped_start < clipped_end:
                    if group_start < 0:
                        group_start = clipped_start
                    group_end = clipped_end
                    if self._copy_to_fp32:
                        param_offset = clipped_start - shard_start
                        param_size = clipped_end - clipped_start
                        buffer_start = (block_id * self._num_chunks + chunk_id) * self._shard_size + param_offset
                        buffer_end = buffer_start + param_size
                        param_start = (clipped_start - start)
                        param_end = param_start + param_size
                        self._fp32_p[buffer_start:buffer_end].copy_(p.view(-1)[param_start:param_end].float())
            group_size = group_end - group_start
            if group_size > 0:
                assert (step is not None), "state['step'] is None for this parameter group"
                group_offset = group_start - shard_start
                group_shard_start = shard_start + group_offset
                group_shard_end = group_shard_start + group_size
                group_buffer_start = (block_id * self._num_chunks + chunk_id) * self._shard_size + group_offset
                group_buffer_end = group_buffer_start + group_size
                beta1, beta2 = group['betas']
                if undo:
                    if self._revert_method == 1:
                        fused_adam_cuda.maybe_adam_undo(
                                         torch.empty([0]),
                                         self._fp32_p[group_buffer_start:group_buffer_end],
                                         self._fp32_m[group_buffer_start:group_buffer_end],
                                         self._fp32_v[group_buffer_start:group_buffer_end],
                                         self._flat_grads[group_shard_start:group_shard_end],
                                         group['lr'],
                                         beta1,
                                         beta2,
                                         group['eps'],
                                         combined_scale,
                                         step+1, # FIXME: Verify this should be step+1
                                         self.eps_mode,
                                         bias_correction,
                                         group['weight_decay'])
                    elif self._revert_method == 2:
                        self._fp32_p[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_p[group_buffer_start:group_buffer_end])
                        self._fp32_m[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_m[group_buffer_start:group_buffer_end])
                        self._fp32_v[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_v[group_buffer_start:group_buffer_end])
                    elif self._revert_method == 3:
                        raise RuntimeError('revert_step debug option not implemented yet')
                else:
                    if self._revert_method > 1:
                        self._fp32_backup_p[group_buffer_start:group_buffer_end].copy_(self._fp32_p[group_buffer_start:group_buffer_end])
                        self._fp32_backup_m[group_buffer_start:group_buffer_end].copy_(self._fp32_m[group_buffer_start:group_buffer_end])
                        self._fp32_backup_v[group_buffer_start:group_buffer_end].copy_(self._fp32_v[group_buffer_start:group_buffer_end])
                    fused_adam_cuda.adam(
                                     self._fp32_p[group_buffer_start:group_buffer_end],
                                     self._new_params[group_shard_start:group_shard_end],
                                     self._fp32_m[group_buffer_start:group_buffer_end],
                                     self._fp32_v[group_buffer_start:group_buffer_end],
                                     self._flat_grads[group_shard_start:group_shard_end],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     combined_scale,
                                     step+1,
                                     self.eps_mode,
                                     bias_correction,
                                     group['weight_decay'])
    def complete_reductions(self):
        """Complete reductions if full pipeline is not selected or overlap is not allowed.
        """
@@ -521,17 +493,34 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        if self._compute_L2_grad_norm:
            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

        self._copy_to_fp32 = False
        self._decomp_stats = None
        self._current_block = self._num_blocks
        self._grads_generated = [False]*len(self._grads_info)
    def revert_step(self):
        """Revert effect of previously calling partial_step.
        """
        for block_id in range(self._num_blocks):
            for chunk in range(self._num_chunks):
                self._partial_step_single_shard(block_id, chunk, undo=True)
        # Call undo kernel once per step
        combined_scale = self._global_scale
        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
            combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.maybe_adam_undo(
                         torch.empty([0]),
                         self._fp32_p,
                         self._fp32_m,
                         self._fp32_v,
                         self._fp16_g,
                         self._param_group['lr'],
                         beta1,
                         beta2,
                         self._param_group['eps'],
                         combined_scale,
                         self._param_state['step']+1,
                         self.eps_mode,
                         bias_correction,
                         self._param_group['weight_decay'])
    def step(self, closure=None, skip_overflow_check=False):
        loss = None
@@ -539,14 +528,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
            loss = closure()

        if self._last_step or not self._overlap_reductions or not self._full_pipeline:
            for block_id in range(self._num_blocks-1, -1, -1):
                self._pipeline_block_step(block_id)
            self._pipeline_step()

        with torch.cuda.stream(self._completion_st):
            for block_id in range(self._num_blocks-1, -1, -1):
                for chunk in range(self._num_chunks):
                    self._allgather_works[block_id][chunk].wait()

        # Check for overflow
        # Store state for loss scaler calculation
        if skip_overflow_check:
@@ -559,40 +543,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
            self.revert_step()
        else:
            # Copy self._new_params to model params
            if self._e5m2_allgather or self._do_not_flatten_model:
                p_in = []
                p_out = []
            with torch.no_grad():
                param_i = 0
                for group in self.param_groups:
                    for p in group['params']:
                        if not p.requires_grad:
                            continue
                        state = self.state[p]
                        if len(state) == 0:
                            state['step'] = 0
                        state['step'] += 1
                        nels = p.numel()
                        offset = self._grads_info[param_i]['param_offset']
                        if self._e5m2_allgather or self._do_not_flatten_model:
                            p_in.append(self._new_params[offset:offset+nels].view_as(p))
                            p_out.append(p)
                        else:
                            p.set_(self._new_params[offset:offset+nels].view_as(p))
                        param_i += 1
                if self._e5m2_allgather:
                    multi_tensor_applier(fused_adam_cuda.maybe_cast_mt, self._overflow_buf, [p_in, p_out]);
                elif self._do_not_flatten_model:
                    multi_tensor_applier(amp_C.multi_tensor_scale, self._overflow_buf, [p_in, p_out], 1.0);
                if not self._e5m2_allgather and not self._do_not_flatten_model:
                    self._new_params = None
                multi_tensor_applier(fused_adam_cuda.maybe_cast_mt, self._overflow_buf, self._packed_flat_to_model_params)
            torch.cuda.current_stream().wait_stream(self._completion_st)
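When the model is flattened and e5m2 is off, the loop above points each parameter at its slice of the gathered buffer with set_, so no copy is made. A tiny standalone sketch of that view_as / set_ aliasing with toy shapes:

import torch

flat_new_params = torch.arange(12.)
p = torch.zeros(3, 4)
offset, nels = 0, p.numel()

# After set_, p shares storage with the flat buffer instead of owning a copy,
# so later writes into the buffer also show up in p.
p.set_(flat_new_params[offset:offset + nels].view_as(p))
flat_new_params[5] = -1.0
assert p.view(-1)[5] == -1.0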