Commit a60bbe63, authored May 04, 2020 by Thor Johnsen

    Try out different partition scheme

Parent: 7da28fc3

Showing 1 changed file with 357 additions and 251 deletions:

apex/contrib/optimizers/distributed_fused_adam_v2.py (+357, -251)
...
...
@@ -68,67 +68,234 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
        self._overflow_buf = torch.cuda.IntTensor([0])
        self._predivide = predivide
        assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
        # Way to revert a step
        # 3 -> undo kernel + double buffer (debug, print norm of difference)
        # 2 -> double buffer fp32 parameters
        # 1 -> undo kernel
        self._revert_method = revert_method
        if self._revert_method > 1:
            print("revert_method -> double buffer fp32 parameters, will consume more memory")
        self._last_step = False
        self._overlap_reductions = overlap_reductions
        self._global_scale = None
        self._num_blocks = dwu_num_blocks
        self._num_chunks = dwu_num_chunks
        self._predivide = predivide
        self._e5m2_allgather = e5m2_allgather
        self._do_not_flatten_model = do_not_flatten_model
        self._full_pipeline = full_pipeline
        self._compute_L2_grad_norm = compute_L2_grad_norm
        self._L2_grad_norm = None
        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
        self._group_id = torch.distributed.get_rank() // self._group_size
        self._num_groups = torch.distributed.get_world_size() // self._group_size
        self._rank_in_group = torch.distributed.get_rank() % self._group_size
        self._rank = torch.distributed.get_rank()
        self._rank_in_group = self._rank % self._group_size
        self._world_size = torch.distributed.get_world_size()
        self._num_groups = self._world_size // self._group_size
        self._rank_in_group = torch.distributed.get_rank() % self._group_size

        p_offset = 0
        p_i = 0
        self._param_state = None
        self._model_params = []
        self._grads_info = []
        self._grad_accs = []
        for group in self.param_groups:
            self._param_group = group
            prev = None
            for p in group['params']:
                torch.distributed.broadcast(p, 0)
                if not p.requires_grad:
                    continue
                self._model_params.append(p)
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                if self._param_state is None:
                    self._param_state = state
                p_grads_size = p.numel()
                def wrapper(param, param_i, param_grads_size, param_offset):
                    def allreduce_hook(grad):
                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, grad)
                    param.register_hook(allreduce_hook)
                    param_tmp = param.expand_as(param)
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    def allreduce_hook(*unused):
                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
                    grad_acc.register_hook(allreduce_hook)
                    self._grad_accs.append(grad_acc)
                self._grads_info.append({"param_grads_size": p_grads_size, "param_offset": p_offset})
                wrapper(p, p_i, p_grads_size, p_offset)
                p_offset += p_grads_size
                # enforce 128b alignment (64 * fp16)
                p_offset = ((p_offset + 63) // 64) * 64
                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
                # RNN is one example of consecutive parameters:
                # (weight_ih, weight_hh, bias_ih, bias_hh)
                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
                    p_offset = ((p_offset + 63) // 64) * 64
                prev = p
                p_i += 1
        self._grads_generated = [False] * len(self._grads_info)
        self._grads = [None] * len(self._grads_info)
        self._current_block = self._group_size
        self._flat_mt = flat_mt
        self._grads = []
        if self._overlap_reductions:
            self._current_block = self._num_blocks

        self._net_total_param_size = p_offset
        self._total_param_size = p_offset
        min_page_size = 256 * self._group_size
        self._total_param_size = ((self._total_param_size + min_page_size - 1) // min_page_size) * min_page_size
        self._block_size = self._total_param_size // self._group_size
        print("self._net_total_param_size=%d, self._total_param_size=%d, min_page_size=%d, self._block_size=%d" % (self._net_total_param_size, self._total_param_size, min_page_size, self._block_size))
        self._low_param_i = [0] * self._group_size
        for block_id in range(self._group_size-1, -1, -1):
        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
        self._block_size = self._total_param_size // self._num_blocks
        self._shard_size = self._block_size // self._group_size
        self._chunk_size = self._shard_size // self._num_chunks
        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d, self._chunk_size=%d" % (self._net_total_param_size, self._total_param_size, dwu_min_page_size, self._block_size, self._shard_size, self._chunk_size))
        self._low_param_i = [0] * self._num_blocks
        for block_id in range(self._num_blocks-1, -1, -1):
            p_i = len(self._grads_info)-1
            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
                p_i -= 1
            self._low_param_i[block_id] = p_i
        print(self._low_param_i)

        self._global_scale = 1.0
        self._fp32_p = None
        self._new_params = torch.zeros(size=[self._total_param_size], dtype=torch.uint8).cuda()
        self._flat_grads = torch.zeros(size=[self._total_param_size], dtype=torch.float16).cuda()
        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
        self._mega_shard_size = self._num_blocks * self._num_chunks * self._chunk_size
        self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        # FIXME: Rethink fp16 label since it's either uint8 or fp16
        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')

        self._individual_flat_grads = []
        for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
            self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))

        def _flat_split(p):
            def __blockify(p):
                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
            def __shardify(p):
                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
            def __chunkify(p):
                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._group_size)]
            list_of_blocks = __blockify(self._flat_grads)
            list_of_list_of_shards = [__shardify(block) for block in list_of_blocks]
            list_of_list_of_list_of_chunks = [[__chunkify(shard) for shard in shards] for shards in list_of_list_of_shards]
            return list_of_blocks, list_of_list_of_shards, list_of_list_of_list_of_chunks
        self._flat_grads_blocks, self._flat_grads_shards, self._flat_grads_chunks = _flat_split(self._flat_grads)
        def _full_packed_split(p):
            def __shardify(p):
                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
            def __blockify(p):
                return [p[block_id*self._num_chunks*self._chunk_size:(block_id+1)*self._num_chunks*self._chunk_size] for block_id in range(self._num_blocks)]
            def __chunkify(p):
                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
            list_of_mega_shards = __shardify(p)
            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
        def _packed_split(p):
            def __packed_blockify(p):
                packed_block_size = self._num_chunks*self._chunk_size
                return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
            def __packed_chunkify(p):
                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
            list_of_blocks = __packed_blockify(p)
            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
            return list_of_blocks, list_of_list_of_chunks
        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)

        # current arrangement
        #
        # self._flat_grads
        # self._flat_grads_blocks [x self._num_blocks, self._block_size]
        # self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
        # self._flat_grads_shards [x self._group_size, self._shard_size]
        #
        # self._new_params
        # self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._shard_size]
        # self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
        # self._new_params_mega_chunks [x self._num_chunks, self._shard_size]
        #
        # self._fp32_p
        # self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
        # self._fp32_p_chunks [x self._num_chunks, self._shard_size]
        # each chunk contains one shard
        # same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
        #
        # Usage:
        #
        # for chunk_id in range(self._num_chunks):
        #     works[chunk_id] = torch.distributed.reduce_scatter(self._flat_grads_chunks[block_id][chunk_id], self._fp16_g_chunks[block_id][chunk_id], ...)
        #
        # ----------------------------------------------------------------------------------------
        #
        # new arrangement
        #
        # NB! New equations for self._shard_size and self._chunk_size
        #
        # self._flat_grads
        # self._flat_grads_blocks [x self._num_blocks, self._block_size]
        # self._flat_grads_shards [x self._group_size, self._shard_size]
        # self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
        #
        # self._new_params
        # self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._chunk_size]
        # self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
        # self._new_params_mega_chunks [x self._num_chunks, self._chunk_size]
        #
        # self._fp32_p
        # self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
        # self._fp32_p_chunks [x self._num_chunks, self._chunk_size]
        # same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
        #
        # Usage:
        #
        # work = torch.distributed.reduce_scatter(self._flat_grads_blocks[block_id], self._fp16_g[block_id], ...)
        # for chunk_id in range(self._num_chunks):
        #     work.wait()
        #     works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], ...)
        # or
        # work.wait()
        # works[0] = torch.distributed.all_reduce(self._fp16_g_blocks[block_id], ...)
        #
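        # Worked example of the new partitioning arithmetic (illustrative numbers only,
        # not taken from this commit): with dwu_num_blocks=4, dwu_num_chunks=2 and
        # dwu_group_size=8, dwu_min_page_size = 256*4*2*8 = 16384, so a net parameter
        # count of 100000 is padded to self._total_param_size = 114688. Then
        # self._block_size = 114688 // 4 = 28672, self._shard_size = 28672 // 8 = 3584,
        # self._chunk_size = 3584 // 2 = 1792, and self._mega_shard_size = 4*2*1792 =
        # 14336, i.e. exactly this rank's slice of every block.
        #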
        # This paragraph does two things:
        # 1) Copy model parameters into master buffer
        # 2) Create tensor lists for unpacking new parameter tensor after all-gather
        self._packed_flat_to_model_params = []
        for shard_id in range(self._group_size):
            for block_id in range(self._num_blocks):
                flat_shard_start = (block_id * self._group_size + shard_id) * self._shard_size
                flat_shard_end = flat_shard_start + self._shard_size
                for p, grads_info in zip(self._model_params, self._grads_info):
                    flat_grad_start = grads_info["param_offset"]
                    flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
                    clipped_start = (lambda a, b: a if a > b else b)(flat_grad_start, flat_shard_start)
                    clipped_end = (lambda a, b: a if a < b else b)(flat_grad_end, flat_shard_end)
                    if clipped_start < clipped_end:
                        grad_offset = clipped_start - flat_grad_start
                        grad_length = clipped_end - clipped_start
                        shard_offset = clipped_start - flat_shard_start
                        model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
                        new_param_packed_fragment = self._new_params_mega_blocks[shard_id][block_id][shard_offset:shard_offset+grad_length]
                        self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
                        if shard_id == self._rank_in_group:
                            # copy model parameters into master buffer
                            master_param_fragment = self._fp32_p_blocks[block_id][shard_offset:shard_offset+grad_length]
                            print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
                            master_param_fragment.copy_(model_param_fragment)
        p_in, p_out = zip(*self._packed_flat_to_model_params)
        self._packed_flat_to_model_params = [p_in, p_out]

        self._distributed_weight_update = distributed_weight_update  # Is this still needed?
        self._num_rs_pg = dwu_num_rs_pg
        self._num_ar_pg = dwu_num_ar_pg
        self._num_ag_pg = dwu_num_ag_pg
        if self._num_groups > 1:
            self._num_ar_pg = dwu_num_ar_pg
            self._ar_pg = []
            for dev_i in range(self._group_size):
                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
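                # For illustration (assumed values, not from this commit): with
                # self._group_size=8 and self._num_groups=2, dev_i=3 gives
                # ranks=[3, 11], i.e. each all-reduce group ties together the GPUs
                # that hold the same intra-node position on every node.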
...
...
@@ -136,10 +303,9 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
                    grp = torch.distributed.new_group(ranks=ranks)
                    if torch.distributed.get_rank() in ranks:
                        self._ar_pg.append(grp)
            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
            for ar_pg in self._ar_pg:
                torch.distributed.all_reduce(self._overflow_buf, group=ar_pg)
        self._num_rs_pg = dwu_num_rs_pg
        rs_ranks = []
        for group_i in range(self._num_groups):
            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
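        # Again purely illustrative (assumed values): with self._num_groups=2 and
        # self._group_size=8, rs_ranks = [[0..7], [8..15]], so reduce-scatter groups
        # stay within a node while the all-reduce groups above span nodes.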
...
...
@@ -150,22 +316,39 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
                grp = torch.distributed.new_group(ranks=ranks)
                if torch.distributed.get_rank() in ranks:
                    self._rs_pg.append(grp)
            if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
                self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
                torch.distributed.all_reduce(self._overflow_buf, group=self._l2_grad_norm_pg)
        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
        for rs_pg in self._rs_pg:
            torch.distributed.all_reduce(self._overflow_buf, group=rs_pg)
        self._redux_st = [torch.cuda.Stream() for _ in range(self._group_size)]
        self._compute_L2_grad_norm = compute_L2_grad_norm
        if self._compute_L2_grad_norm:
            self._L2_grad_norm = torch.zeros(size=[1], dtype=torch.float32).cuda()
            self._l2_grad_norm_st = torch.cuda.Stream()
        if self._num_ag_pg == 0:
            self._ag_pg = self._rs_pg
            self._ag_st = self._rs_st
            self._num_ag_pg = self._num_rs_pg
        else:
            self._ag_pg = []
            for group_i in range(self._num_groups):
                ranks = rs_ranks[group_i]
                for i in range(self._num_ag_pg):
                    grp = torch.distributed.new_group(ranks=ranks)
                    if torch.distributed.get_rank() in ranks:
                        self._ag_pg.append(grp)
            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
            for ag_pg in self._ag_pg:
                torch.distributed.all_reduce(self._overflow_buf, group=ag_pg)
        self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
        self._completion_st = torch.cuda.Stream()

        self._last_step = False

        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks

        import inspect
        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"

    def set_last_step(self, last_step):
        self._last_step = last_step

    def _get_flush_block(self):
        flush_block = []
...
...
@@ -187,77 +370,112 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
        return flush_block

    def _pipeline_block_reductions(self, block_id):
        self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)

        start = block_id * self._block_size
        end = start + self._block_size
        grad_block = self._flat_grads[start:end]
        active_rank = self._group_id * self._group_size + block_id
        redux_stream = self._redux_st[block_id]
        redux_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(redux_stream):
            work = torch.distributed.reduce(grad_block, active_rank, group=self._rs_pg[block_id%self._num_rs_pg], async_op=True)
            if self._num_groups > 1 and self._rank == active_rank:
                work.wait()
                work = torch.distributed.all_reduce(grad_block, group=self._ar_pg[block_id%self._num_ar_pg], async_op=True)
            if self._compute_L2_grad_norm:
                if self._rank == active_rank:
                    with torch.cuda.stream(self._l2_grad_norm_st):
                        work.wait()
                        self._L2_grad_norm = grad_block.norm(dtype=torch.float32, p=2)**2
                if block_id == 0:
                    with torch.cuda.stream(self._l2_grad_norm_st):
                        torch.distributed.all_reduce(self._L2_grad_norm, group=self._rs_pg[self._num_rs_pg-1])
                        self._L2_grad_norm.sqrt_()
                    # FIXME: Does completion stream need to wait for L2 grad norm to finish?
                    self._completion_st.wait_stream(self._l2_grad_norm_st)
        with torch.cuda.stream(redux_stream):
            work.wait()

        # Reduction within each node
        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
        # The output format is the same as the fp32 master parameters
        works = [None]*self._num_chunks
        rs_stream = self._rs_st[block_id%self._num_rs_pg]
        rs_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(rs_stream):
            rs_work = torch.distributed.reduce_scatter(self._fp16_g_blocks[block_id], self._flat_grads_shards[block_id], group=self._rs_pg[block_id%self._num_rs_pg], async_op=True, no_copy=True)
            for chunk_id in range(self._num_chunks):
                works[chunk_id] = rs_work

        # Reduction across nodes for each rank
        if self._num_groups > 1:
            for chunk_id in range(self._num_chunks):
                glob_chunk_id = block_id * self._num_chunks + chunk_id
                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
                with torch.cuda.stream(ar_stream):
                    rs_work.wait()
                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], group=self._ar_pg[glob_chunk_id%self._num_ar_pg], async_op=True)
        self._reductions_works[block_id] = works

        # Optionally compute L2 grad norm
        if self._compute_L2_grad_norm and block_id == 0:
            with torch.cuda.stream(self._l2_grad_norm_st):
                for block_id in range(self._num_blocks):
                    for chunk_id in range(self._num_chunks):
                        self._reductions_works[block_id][chunk_id].wait()
                # Since the packed format is contiguous after reductions, only one norm is needed
                l2_grad_norm_sq = torch.empty([1], device='cuda')
                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
                self._L2_grad_norm = l2_grad_norm_sq.sqrt()
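
    # Stream/process-group assignment above, with assumed values (not from this
    # commit): for dwu_num_rs_pg=2 and dwu_num_ar_pg=4, block_id=3 uses
    # reduce-scatter stream/pg index 3 % 2 = 1, and with self._num_chunks=2 its
    # chunks map to glob_chunk_id 6 and 7, i.e. all-reduce stream/pg indices
    # 6 % 4 = 2 and 7 % 4 = 3. The intra-node reduce-scatter leaves each rank
    # holding its shard of the block in self._fp16_g_blocks[block_id], which the
    # inter-node all-reduce then combines across groups chunk by chunk.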
    def __launch_step_kernel(self, p, p_copy, m, v, g):
        combined_scale = self._global_scale
        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
            combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.adam(p, p_copy, m, v, g,
                self._param_group['lr'],
                beta1,
                beta2,
                self._param_group['eps'],
                combined_scale,
                self._param_state['step']+1,
                self.eps_mode,
                bias_correction,
                self._param_group['weight_decay'])
    def _pipeline_block_step(self, block_id):
        active_rank = self._group_id * self._group_size + block_id
        if self._rank == active_rank:
            redux_stream = self._redux_st[block_id]
            with torch.cuda.stream(redux_stream):
                self._partial_step_single_shard(block_id)
        # Call step kernel once per block
        ag_stream = self._ag_st[block_id%self._num_ag_pg]
        with torch.cuda.stream(ag_stream):
            for chunk_id in range(self._num_chunks):
                self._reductions_works[block_id][chunk_id].wait()
            self.__launch_step_kernel(
                self._fp32_p_blocks[block_id],
                self._fp16_p_blocks[block_id],
                self._fp32_m_blocks[block_id],
                self._fp32_v_blocks[block_id],
                self._fp16_g_blocks[block_id])
        # Call all-gather once per step.
        # FIXME: Determine which is faster, one all-gather per block or a single all-gather at end
        if block_id == 0:
            new_params_blocks = [self._new_params[block*self._block_size:(block+1)*self._block_size] for block in range(self._group_size)]
            for redux_stream in self._redux_st:
                self._completion_st.wait_stream(redux_stream)
            for other_ag_stream in self._ag_st:
                self._completion_st.wait_stream(other_ag_stream)
            with torch.cuda.stream(self._completion_st):
                torch.distributed.all_gather(new_params_blocks, new_params_blocks[self._rank_in_group], group=self._rs_pg[self._num_rs_pg-1], no_copy=True)
                torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)

    def _pipeline_step(self):
        # Call step kernel once per step
        # Call all-gather once per step
        with torch.cuda.stream(self._completion_st):
            for block_id in range(self._num_blocks):
                for chunk_id in range(self._num_chunks):
                    self._reductions_works[block_id][chunk_id].wait()
            self.__launch_step_kernel(
                self._fp32_p,
                self._fp16_p,
                self._fp32_m,
                self._fp32_v,
                self._fp16_g)
            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)

    def _flatten_grad_mt(self, scale):
        grads = []
        flat_grads = []
        for p_i, (grads_info, grad) in enumerate(zip(self._grads_info, self._grads)):
            if grad is not None:
                grads.append(grad)
                flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]])
        self._grads = [None]*len(self._grads_info)
        if len(grads) > 0:
        if self._flat_mt and len(self._grads) > 0:
            self._overflow_buf.zero_()
            multi_tensor_applier(
                    amp_C.multi_tensor_scale,
                    self._overflow_buf,
                    [grads, flat_grads],
                    list(zip(*self._grads)),
                    scale)
            self._grads = []

    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
        # handle overlapped reductions
        self._grads[param_i] = grad.view(-1)
        if self._flat_mt:
            self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
        else:
            torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
        self._grads_generated[param_i] = True
        if not self._last_step:
            if self._overlap_reductions:
...
...
@@ -315,130 +533,6 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
        else:
            return None

    # Distributed weight update algorithm:
    # Model parameters are kept as-is.
    # Gradients are flattened during backprop.
    # Reductions are done with an intra-node reduce-scatter followed by an inter-node all-reduce.
    # Step function is sharded and the shards are assembled with an intra-node all-gather.
    # Sharded step function needs internal fp32 buffers for p, m and v.
    # To save memory, we allocate the fp32 buffers to cover only the shards local GPU will update.
    # This means we have to play around with indexes, which requires knowledge of block and shard number.
    # Implement a method that performs a partial update of a single shard within a single block.
    def _partial_step_single_shard(self, block_id, undo=False):
        """Perform step function for a single shard.

        Arguments:
            block_id (integer): Block index of shard [0,self._group_size>
            undo (boolean, optional): If True, undo effect of previously called partial step.
        """
        block_start = block_id * self._block_size
        block_end = block_start + self._block_size

        if self._fp32_p is None:
            assert (not undo), "Tried to undo step before calling step."
            # Allocate fp32 buffers on demand. Note that we don't make these part of the state
            # since each rank only has partial buffers.
            # To-Do:
            self._fp32_p = torch.zeros([self._block_size]).float().cuda()
            self._fp32_m = torch.zeros([self._block_size]).float().cuda()
            self._fp32_v = torch.zeros([self._block_size]).float().cuda()
            self._copy_to_fp32 = True

        step = None
        param_i = 0
        for group in self.param_groups:
            # compute combined scale factor for this group
            combined_scale = self._global_scale
            if group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
                combined_scale = group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
                combined_scale = self._global_scale / min(1, combined_scale)

            bias_correction = 1 if group['bias_correction'] else 0

            group_start = -1
            group_end = -2

            for p in group['params']:
                if not p.requires_grad:
                    continue
                #if p.grad.is_sparse:
                #    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                if step is None:
                    # all we want from state at this point is state['step'], which should be the same for all p
                    step = state['step']
                nels = p.numel()
                offset = self._grads_info[param_i]['param_offset']
                param_i += 1

                start = offset
                end = start + nels
                clipped_start = start if start >= block_start else block_start
                clipped_end = end if end <= block_end else block_end
                # check if this parameter contributes to block
                if clipped_start < clipped_end:
                    if group_start < 0:
                        group_start = clipped_start
                    group_end = clipped_end

                    if self._copy_to_fp32:
                        param_offset = clipped_start - block_start
                        param_size = clipped_end - clipped_start
                        buffer_start = param_offset
                        buffer_end = buffer_start + param_size
                        param_start = (clipped_start - start)
                        param_end = param_start + param_size
                        #assert (buffer_start >= 0 and buffer_end <= self._fp32_p.numel() and param_start >= 0 and param_end <= p.numel()), "Illegal copy"
                        self._fp32_p[buffer_start:buffer_end].copy_(p.view(-1)[param_start:param_end].float())

            group_size = group_end - group_start
            if group_size > 0:
                assert (step is not None), "state['step'] is None for this parameter group"
                group_offset = group_start - block_start
                group_block_start = block_start + group_offset
                group_block_end = group_block_start + group_size
                group_buffer_start = group_offset
                group_buffer_end = group_buffer_start + group_size

                beta1, beta2 = group['betas']
                if undo:
                    fused_adam_cuda.maybe_adam_undo(
                            torch.empty([0]),
                            self._fp32_p[group_buffer_start:group_buffer_end],
                            self._fp32_m[group_buffer_start:group_buffer_end],
                            self._fp32_v[group_buffer_start:group_buffer_end],
                            self._flat_grads[group_block_start:group_block_end],
                            group['lr'],
                            beta1,
                            beta2,
                            group['eps'],
                            combined_scale,
                            step+1, # FIXME: Verify this should be step+1
                            self.eps_mode,
                            bias_correction,
                            group['weight_decay'])
                else:
                    fused_adam_cuda.adam(
                            self._fp32_p[group_buffer_start:group_buffer_end],
                            self._new_params[group_block_start:group_block_end],
                            self._fp32_m[group_buffer_start:group_buffer_end],
                            self._fp32_v[group_buffer_start:group_buffer_end],
                            self._flat_grads[group_block_start:group_block_end],
                            group['lr'],
                            beta1,
                            beta2,
                            group['eps'],
                            combined_scale,
                            step+1,
                            self.eps_mode,
                            bias_correction,
                            group['weight_decay'])
    def complete_reductions(self):
        """Complete reductions if full pipeline is not selected or overlap is not allowed.
        """
...
...
@@ -455,61 +549,73 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
        if self._last_step or not self._overlap_reductions:
            # nothing done so far, run full pipeline after reductions
            for block_id in range(self._group_size-1, -1, -1):
            for block_id in range(self._num_blocks-1, -1, -1):
                self._pipeline_block_reductions(block_id)

        self._copy_to_fp32 = False
        self._current_block = self._group_size
        if self._compute_L2_grad_norm:
            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

        self._current_block = self._num_blocks
        self._grads_generated = [False]*len(self._grads_info)

    def revert_step(self):
        """Revert effect of previously calling partial_step.
        """
        self._partial_step_single_shard(self._rank_in_group, undo=True)

    def step(self, closure=None):
        # Call undo kernel once per step
        combined_scale = self._global_scale
        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
            combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.maybe_adam_undo(
                torch.empty([0]),
                self._fp32_p,
                self._fp32_m,
                self._fp32_v,
                self._fp16_g,
                self._param_group['lr'],
                beta1,
                beta2,
                self._param_group['eps'],
                combined_scale,
                self._param_state['step']+1,
                self.eps_mode,
                bias_correction,
                self._param_group['weight_decay'])

    def step(self, closure=None, skip_overflow_check=False):
        loss = None
        if closure is not None:
            loss = closure()

        if self._last_step or not self._overlap_reductions or not self._full_pipeline:
            for block_id in range(self._group_size-1, -1, -1):
                self._pipeline_block_step(block_id)
            self._pipeline_step()

        with torch.cuda.stream(self._completion_st):
            # Check for overflow
            # Store state for loss scaler calculation
            self.strided_check_finite(self._new_params, stride=self._block_size, start=0, end=self._net_total_param_size)
            has_overflow = self.peek_overflow
            if skip_overflow_check:
                has_overflow = False
            else:
                self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
                has_overflow = self.peek_overflow
            if has_overflow:
                print("Reverting step")
                self.revert_step()
            else:
                # Copy self._new_params to model params
                p_in = []
                p_out = []
                with torch.no_grad():
                    param_i = 0
                    for group in self.param_groups:
                        for p in group['params']:
                            if not p.requires_grad:
                                continue
                            state = self.state[p]
                            if len(state) == 0:
                                state['step'] = 0
                            state['step'] += 1
                            nels = p.numel()
                            offset = self._grads_info[param_i]['param_offset']
                            p_in.append(self._new_params[offset:offset+nels].view_as(p))
                            p_out.append(p)
                            param_i += 1
                    multi_tensor_applier(
                            fused_adam_cuda.maybe_cast_mt,
                            self._overflow_buf,
                            [p_in, p_out]);
                    for p in self._model_params: self.state[p]['step'] += 1
                    multi_tensor_applier(
                            fused_adam_cuda.maybe_cast_mt,
                            self._overflow_buf,
                            self._packed_flat_to_model_params)

        torch.cuda.current_stream().wait_stream(self._completion_st)

        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks

        return loss
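
For orientation, a minimal sketch of how this optimizer is typically driven from a training loop. Only complete_reductions, step, set_last_step and the dwu_*/overlap_reductions constructor parameters appear in the diff above; the model, loader and learning-rate names below are assumptions for illustration, not part of this commit.

    # Hypothetical driver loop (names other than the optimizer methods are assumptions)
    optimizer = DistributedFusedAdamV2(model.parameters(), lr=1e-3,
                                       dwu_num_blocks=4, dwu_num_chunks=4,
                                       overlap_reductions=True)
    for step_i, batch in enumerate(loader):
        loss = compute_loss(model, batch)
        loss.backward()                    # gradient hooks feed _do_overlapped_reduction
        optimizer.set_last_step(step_i == total_steps - 1)
        optimizer.complete_reductions()    # finish reductions not overlapped with backprop
        optimizer.step()                   # sharded Adam step + all-gather of new params
        optimizer.zero_grad()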