OpenDAS / apex / Commits

Commit 208c91e0
authored Apr 14, 2020 by Thor Johnsen

internal pipelining more similar to micro-benchmarks

parent 7ba6a038

Showing 1 changed file with 119 additions and 125 deletions.

apex/contrib/optimizers/distributed_fused_adam.py @ 208c91e0 (+119, -125)
 import math
 import torch
 import importlib
+import amp_C
 from apex.multi_tensor_apply import multi_tensor_applier

 class DistributedFusedAdam(torch.optim.Optimizer):

@@ -46,9 +47,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
-                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
-                 dwu_num_chunks=4, predivide=True, internal_pipeline=False,
-                 e5m2_allgather=False):
+                 dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
+                 dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
+                 do_not_flatten_model=False):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
@@ -67,6 +68,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._overflow_buf = torch.cuda.IntTensor([0])
+        assert (not flat_mt), "flat_mt option is not safe in this version"
         # Way to revert a step
         # 3 -> undo kernel + double buffer (debug, print norm of difference)
         # 2 -> double buffer fp32 parameters
@@ -81,8 +84,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._predivide = predivide
-        self._internal_pipeline = internal_pipeline
         self._e5m2_allgather = e5m2_allgather
+        self._do_not_flatten_model = do_not_flatten_model
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
@@ -118,11 +121,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._net_total_param_size = p_offset
         self._total_param_size = p_offset
-        dwu_min_page_size = 256 * self._num_blocks * self._group_size
+        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
         self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
         self._block_size = self._total_param_size // self._num_blocks
-        self._shard_size = self._block_size // self._group_size
-        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size, dwu_min_page_size, self._block_size, self._shard_size))
+        self._chunk_size = self._block_size // self._num_chunks
+        self._shard_size = self._chunk_size // self._group_size
+        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size, dwu_min_page_size, self._block_size, self._chunk_size, self._shard_size))
         self._low_param_i = [0]*self._num_blocks
         for block_id in range(self._num_blocks-1,-1,-1):
@@ -143,7 +147,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._num_rs_pg = dwu_num_rs_pg
         self._num_ar_pg = dwu_num_ar_pg
         self._num_ag_pg = dwu_num_ag_pg
-        self._num_blk_st = dwu_num_blk_st
         if self._num_groups > 1:
             self._ar_pg = []
             for dev_i in range(self._group_size):
@@ -152,6 +155,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     grp = torch.distributed.new_group(ranks=ranks)
                     if torch.distributed.get_rank() in ranks:
                         self._ar_pg.append(grp)
+            self._ar_st = [torch.cuda.Stream()]*self._num_ar_pg
         rs_ranks = []
         for group_i in range(self._num_groups):
             rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
@@ -162,8 +166,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 grp = torch.distributed.new_group(ranks=ranks)
                 if torch.distributed.get_rank() in ranks:
                     self._rs_pg.append(grp)
+            if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
+                self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+        self._rs_st = [torch.cuda.Stream()]*self._num_rs_pg
         if self._num_ag_pg == 0:
             self._ag_pg = self._rs_pg
+            self._ag_st = self._rs_st
+            self._num_ag_pg = self._num_rs_pg
         else:
             self._ag_pg = []
             for group_i in range(self._num_groups):
@@ -172,16 +181,15 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     grp = torch.distributed.new_group(ranks=ranks)
                     if torch.distributed.get_rank() in ranks:
                         self._ag_pg.append(grp)
-        self._blk_st = []
-        for i in range(self._num_blk_st):
-            self._blk_st.append(torch.cuda.Stream())
+            self._ag_st = [torch.cuda.Stream()]*self._num_ag_pg
+        self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
+        self._completion_st = torch.cuda.Stream()
+        self._reductions_works = [None]*self._num_blocks
+        self._allgather_works = [None]*self._num_blocks
         import inspect
-        if 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args:
-            self._pg_supports_no_copy = True
-        else:
-            self._pg_supports_no_copy = False
-            print("WARNING! torch.distributed.reduce_scatter does not support no_copy op.")
+        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"

     def set_last_step(self, last_step):
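The rewritten constructor drops the generic per-block streams (`_blk_st`) in favour of one CUDA stream per reduce-scatter, all-reduce and all-gather process group, plus dedicated streams for the L2-norm and completion work. Below is a minimal standalone sketch of that setup pattern; `make_parallel_groups` and its arguments are illustrative only (not part of apex), and it assumes `torch.distributed` has already been initialized with the NCCL backend on a CUDA device.

```python
# Sketch: one process group + one CUDA stream per concurrent collective,
# so independent chunks can be reduced on separate NCCL channels/streams.
import torch
import torch.distributed as dist

def make_parallel_groups(num_pg):
    """Create num_pg process groups spanning all ranks, plus one stream each."""
    world = list(range(dist.get_world_size()))
    groups = [dist.new_group(ranks=world) for _ in range(num_pg)]
    streams = [torch.cuda.Stream() for _ in range(num_pg)]
    return groups, streams

# Usage (after dist.init_process_group("nccl") and torch.cuda.set_device(...)):
#   rs_pg, rs_st = make_parallel_groups(num_pg=1)   # reduce-scatter
#   ar_pg, ar_st = make_parallel_groups(num_pg=4)   # all-reduce
```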
@@ -207,71 +215,65 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         return flush_block

-    def _pipeline_block_reductions(self, block_id, flat_grads):
-        start = block_id * self._block_size
-        end = start + self._block_size
-        grad_block = flat_grads[start:end]
-        grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
-        if self._internal_pipeline:
-            works = []
-            chunk_size = self._shard_size // self._num_chunks
-            for i in range(self._num_chunks):
-                chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
-                if self._pg_supports_no_copy:
-                    work = torch.distributed.reduce_scatter(chunks[self._rank_in_group], chunks, group=self._rs_pg[i%len(self._rs_pg)], async_op=True, no_copy=True)
-                else:
-                    work = torch.distributed.reduce_scatter(chunks[self._rank_in_group], chunks, group=self._rs_pg[i%len(self._rs_pg)], async_op=True)
-                if self._num_groups > 1:
-                    work.wait()
-                    work = torch.distributed.all_reduce(chunks[self._rank_in_group], group=self._ar_pg[i%len(self._ar_pg)], async_op=True)
-                works.append(work)
-        else:
-            if self._pg_supports_no_copy:
-                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group], grad_shards, group=self._rs_pg[block_id%len(self._rs_pg)], async_op=True, no_copy=True)
-            else:
-                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group], grad_shards, group=self._rs_pg[block_id%len(self._rs_pg)], async_op=True)
-            works = [work]
-            if self._num_groups > 1:
-                work.wait()
-                works = []
-                chunk_size = self._shard_size // self._num_chunks
-                for i in range(self._num_chunks):
-                    chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
-                    work = torch.distributed.all_reduce(chunks[self._rank_in_group], group=self._ar_pg[i%len(self._ar_pg)], async_op=True)
-                    works.append(work)
-        if self._compute_L2_grad_norm:
-            with torch.cuda.stream(self._blk_st[0]):
-                for work in works:
-                    work.wait()
-                if block_id+1 == self._num_blocks:
-                    self._L2_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32, p=2)**2
-                elif block_id != 0:
-                    self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32, p=2)**2
-                else:
-                    self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32, p=2)**2
-                    torch.distributed.all_reduce(self._L2_grad_norm, group=self._rs_pg[0])
-                    self._L2_grad_norm.sqrt_()
-        for work in works:
-            work.wait()
-
-    # NB!
-    # self._global_scale is used by this method.
-    def _pipeline_block_step(self, block_id, flat_grads, new_params):
-        start = block_id * self._block_size
-        new_params_shards = [new_params[start+shard_i*self._shard_size:start+(shard_i+1)*self._shard_size] for shard_i in range(self._group_size)]
-        self._partial_step_single_shard(block_id)
-        if self._pg_supports_no_copy:
-            torch.distributed.all_gather(new_params_shards, new_params_shards[self._rank_in_group], group=self._ag_pg[block_id%len(self._ag_pg)], no_copy=True)
-        else:
-            torch.distributed.all_gather(new_params_shards, new_params_shards[self._rank_in_group], group=self._ag_pg[block_id%len(self._ag_pg)])
-
-    def _pipeline_block(self, block_id, flat_grads, new_params):
-        self._pipeline_block_reductions(block_id, flat_grads)
-        self._pipeline_block_step(block_id, flat_grads, new_params)
+    def _pipeline_block_reductions(self, block_id):
+        self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
+        start = block_id * self._block_size
+        end = start + self._block_size
+        grad_block = self._flat_grads[start:end]
+        works = [None]*self._num_chunks
+        for chunk in range(self._num_chunks):
+            grad_chunk = grad_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
+            grad_shards = [grad_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
+            rs_stream = self._rs_st[chunk % self._num_rs_pg]
+            rs_stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(rs_stream):
+                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group], grad_shards, group=self._rs_pg[chunk % self._num_rs_pg], async_op=True, no_copy=True)
+            if self._num_groups > 1:
+                ar_stream = self._ar_st[chunk % self._num_ar_pg]
+                with torch.cuda.stream(ar_stream):
+                    work.wait()
+                    work = torch.distributed.all_reduce(grad_shards[self._rank_in_group], group=self._ar_pg[chunk % self._num_ar_pg], async_op=True)
+            works[chunk] = work
+        if self._compute_L2_grad_norm:
+            for chunk in range(self._num_chunks):
+                grad_chunk = grad_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
+                grad_shards = [grad_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
+                with torch.cuda.stream(self._l2_grad_norm_st):
+                    works[chunk].wait()
+                    l2_grad_sq = grad_shards[self._rank_in_group].norm(dtype=torch.float32, p=2)**2
+                    if block_id+1 == self._num_blocks and chunk == 0:
+                        self._L2_grad_norm = l2_grad_sq
+                    else:
+                        self._L2_grad_norm += l2_grad_sq
+                    if block_id == 0 and chunk+1 == self._num_chunks:
+                        torch.distributed.all_reduce(self._L2_grad_norm, group=self._l2_grad_norm_pg)
+                        self._L2_grad_norm.sqrt_()
+        self._reductions_works[block_id] = works
+
+    def _pipeline_block_step(self, block_id):
+        if self._new_params is None:
+            self._new_params = torch.zeros_like(self._flat_grads, dtype=uint8 if self._e5m2_allgather else self._flat_grads.dtype)
+        start = block_id * self._block_size
+        end = start + self._block_size
+        new_params_block = self._new_params[start:end]
+        works = [None]*self._num_chunks
+        for chunk in range(self._num_chunks):
+            new_params_chunk = new_params_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
+            new_params_shards = [new_params_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
+            ag_stream = self._ag_st[chunk % self._num_ag_pg]
+            with torch.cuda.stream(ag_stream):
+                self._reductions_works[block_id][chunk].wait()
+                self._partial_step_single_shard(block_id, chunk)
+                work = torch.distributed.all_gather(new_params_shards, new_params_shards[self._rank_in_group], group=self._ag_pg[chunk % self._num_ag_pg], async_op=True, no_copy=True)
+                works[chunk] = work
+        self._allgather_works[block_id] = works

     def _flatten_grad_mt(self, scale):
         if self._flat_mt:
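The new `_pipeline_block_reductions` splits each block into `dwu_num_chunks` chunks, launches one asynchronous reduce-scatter per chunk on its own stream (optionally followed by an inter-group all-reduce), and stores the work handles in `self._reductions_works` so the step/all-gather phase can wait on them later. The sketch below shows the same pattern in isolation with plain c10d calls; the function name, arguments and sizes are illustrative, and it uses the stock `reduce_scatter` signature (the commit itself relies on an NVIDIA-patched `torch.distributed` that additionally accepts `no_copy=True`).

```python
# Illustrative sketch of the chunked reduction pipeline (not the optimizer's API).
# Assumes dist.init_process_group("nccl") has been called and flat_grads is a
# CUDA tensor whose length is num_chunks * group_size * shard_size.
import torch
import torch.distributed as dist

def reduce_block_in_chunks(flat_grads, num_chunks, group_size, rank_in_group,
                           rs_streams, rs_groups):
    shard_size = flat_grads.numel() // (num_chunks * group_size)
    chunk_size = shard_size * group_size
    works = [None] * num_chunks
    for chunk in range(num_chunks):
        grad_chunk = flat_grads[chunk*chunk_size:(chunk+1)*chunk_size]
        shards = [grad_chunk[i*shard_size:(i+1)*shard_size] for i in range(group_size)]
        stream = rs_streams[chunk % len(rs_streams)]
        stream.wait_stream(torch.cuda.current_stream())  # gradients must be produced first
        with torch.cuda.stream(stream):
            works[chunk] = dist.reduce_scatter(shards[rank_in_group], shards,
                                               group=rs_groups[chunk % len(rs_groups)],
                                               async_op=True)
    return works  # callers wait on these later, as _pipeline_block_step does
```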
@@ -281,10 +283,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             if grad is not None:
                 grads.append(grad)
                 flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]])
-                self._grads[p_i] = None
+        self._grads = [None]*len(self._grads_info)
         if len(grads) > 0:
-            import amp_C
-            from apex.multi_tensor_apply import multi_tensor_applier
             self._overflow_buf.zero_()
             multi_tensor_applier(
                 amp_C.multi_tensor_scale,
@@ -295,7 +295,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
         # handle overlapped reductions
         if self._flat_mt:
-            self._grads[param_i] = grad
+            self._grads[param_i] = grad.view(-1)
         else:
             torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
         self._grads_generated[param_i] = True
@@ -304,19 +304,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             flush_block = self._get_flush_block()
             while flush_block:
                 block_id = flush_block[0] // self._block_size
-                self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
-                self._blk_st[block_id % len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                with torch.cuda.stream(self._blk_st[block_id % len(self._blk_st)]):
-                    if self._full_pipeline:
-                        if self._new_params is None:
-                            if self._e5m2_allgather:
-                                self._new_params = torch.zeros_like(self._flat_grads, dtype=torch.uint8)
-                            else:
-                                self._new_params = torch.zeros_like(self._flat_grads)
-                        self._pipeline_block(block_id, self._flat_grads, self._new_params)
-                    else:
-                        self._pipeline_block_reductions(block_id, self._flat_grads)
+                self._pipeline_block_reductions(block_id)
+                if self._full_pipeline:
+                    self._pipeline_block_step(block_id)
                 flush_block = self._get_flush_block()

     def set_global_scale(self, global_scale):
@@ -360,8 +350,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     @property
     def L2_grad_norm(self):
         if self._compute_L2_grad_norm:
-            for i, blk_st in enumerate(self._blk_st):
-                torch.cuda.current_stream().wait_stream(blk_st)
+            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
             return self._L2_grad_norm
         else:
             return None
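With the per-block streams gone, the `L2_grad_norm` property only has to synchronize with `_l2_grad_norm_st`, the stream on which the per-chunk shard norms were accumulated and reduced. The underlying arithmetic is the usual one: sum the squared local norms across ranks, then take a single square root. A hedged standalone version (assuming an initialized process group and a CUDA shard tensor for the NCCL backend):

```python
# Sketch: global L2 norm from per-shard norms, as the diff's norm()**2 /
# all_reduce / sqrt_() sequence does. Function name is illustrative only.
import torch
import torch.distributed as dist

def global_l2_norm(local_shard, group=None):
    sq = local_shard.norm(dtype=torch.float32, p=2) ** 2  # local squared norm
    dist.all_reduce(sq, group=group)                       # sum of squares over ranks
    return sq.sqrt()
```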
@@ -376,7 +365,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     # This means we have to play around with indexes, which requires knowledge of block and shard number.
     # Implement a method that performs a partial update of a single shard within a single block.
-    def _partial_step_single_shard(self, block_id, undo=False):
+    def _partial_step_single_shard(self, block_id, chunk_id, undo=False):
         """Perform step function for a single shard.

         Arguments:

@@ -385,7 +374,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """
         shard_id = self._rank_in_group
-        shard_start = block_id * self._block_size + shard_id * self._shard_size
+        shard_start = (block_id * self._num_chunks + chunk_id) * self._chunk_size + shard_id * self._shard_size
         shard_end = shard_start + self._shard_size
         if self._fp32_p is None:
@@ -393,13 +382,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             # Allocate fp32 buffers on demand. Note that we don't make these part of the state
             # since each rank only has partial buffers.
             # To-Do:
-            self._fp32_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
-            self._fp32_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
-            self._fp32_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
+            self._fp32_p = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
+            self._fp32_m = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
+            self._fp32_v = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
             if self._revert_method > 1:
-                self._fp32_backup_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
-                self._fp32_backup_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
-                self._fp32_backup_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
+                self._fp32_backup_p = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
+                self._fp32_backup_m = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
+                self._fp32_backup_v = torch.zeros([self._num_blocks*self._num_chunks*self._shard_size]).float().cuda()
             self._copy_to_fp32 = True
         step = None
@@ -445,7 +434,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 if self._copy_to_fp32:
                     param_offset = clipped_start - shard_start
                     param_size = clipped_end - clipped_start
-                    buffer_start = block_id * self._shard_size + param_offset
+                    buffer_start = (block_id * self._num_chunks + chunk_id) * self._shard_size + param_offset
                     buffer_end = buffer_start + param_size
                     param_start = (clipped_start - start)
                     param_end = param_start + param_size

@@ -457,7 +446,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             group_offset = group_start - shard_start
             group_shard_start = shard_start + group_offset
             group_shard_end = group_shard_start + group_size
-            group_buffer_start = block_id * self._shard_size + group_offset
+            group_buffer_start = (block_id * self._num_chunks + chunk_id) * self._shard_size + group_offset
             group_buffer_end = group_buffer_start + group_size

             beta1, beta2 = group['betas']
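The `_partial_step_single_shard` changes re-index the flat buffers so that each (block, chunk) pair owns a contiguous region of `group_size * shard_size` elements, with the shard this rank updates starting at `(block_id * num_chunks + chunk_id) * chunk_size + shard_id * shard_size`. A quick standalone check, with hypothetical sizes, that this tiling covers the parameter buffer exactly once and without overlap:

```python
# Self-contained sanity check of the new flat-buffer indexing; the concrete
# sizes are made up, only the index formula comes from the diff.
num_blocks, num_chunks, group_size, shard_size = 4, 4, 2, 8
chunk_size = group_size * shard_size        # self._chunk_size // self._group_size == shard_size
block_size = num_chunks * chunk_size        # self._block_size // self._num_chunks == chunk_size

covered = []
for block_id in range(num_blocks):
    for chunk_id in range(num_chunks):
        for shard_id in range(group_size):
            start = (block_id * num_chunks + chunk_id) * chunk_size + shard_id * shard_size
            covered.extend(range(start, start + shard_size))

assert sorted(covered) == list(range(num_blocks * block_size))  # exact, non-overlapping tiling
```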
@@ -520,12 +509,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         if self._last_step or not self._overlap_reductions:
             # nothing done so far, run full pipeline after reductions
-            for inv_block_id in range(self._num_blocks):
-                block_id = self._num_blocks - inv_block_id - 1
-                self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
-                self._blk_st[block_id % len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                with torch.cuda.stream(self._blk_st[block_id % len(self._blk_st)]):
-                    self._pipeline_block_reductions(block_id, self._flat_grads)
+            for block_id in range(self._num_blocks-1,-1,-1):
+                self._pipeline_block_reductions(block_id)
+
+        if self._compute_L2_grad_norm:
+            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

         self._copy_to_fp32 = False
         self._decomp_stats = None
@@ -536,7 +524,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """Revert effect of previously calling partial_step.
         """
         for block_id in range(self._num_blocks):
-            self._partial_step_single_shard(block_id, undo=True)
+            for chunk in range(self._num_chunks):
+                self._partial_step_single_shard(block_id, chunk, undo=True)

     def step(self, closure=None, skip_overflow_check=False):
         loss = None
@@ -544,19 +533,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             loss = closure()

         if self._last_step or not self._overlap_reductions or not self._full_pipeline:
-            if self._new_params is None:
-                if self._e5m2_allgather:
-                    self._new_params = torch.zeros_like(self._flat_grads, dtype=torch.uint8)
-                else:
-                    self._new_params = torch.zeros_like(self._flat_grads)
-            for inv_block_id in range(self._num_blocks):
-                block_id = self._num_blocks - inv_block_id - 1
-                with torch.cuda.stream(self._blk_st[block_id % len(self._blk_st)]):
-                    self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
-
-        with torch.cuda.stream(self._blk_st[0]):
-            for i, blk_st in enumerate(self._blk_st):
-                torch.cuda.current_stream().wait_stream(blk_st)
+            for block_id in range(self._num_blocks-1,-1,-1):
+                self._pipeline_block_step(block_id)
+
+        with torch.cuda.stream(self._completion_st):
+            for block_id in range(self._num_blocks-1,-1,-1):
+                for chunk in range(self._num_chunks):
+                    self._allgather_works[block_id][chunk].wait()

         # Check for overflow
         # Store state for loss scaler calculation
@@ -570,7 +553,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             self.revert_step()
         else:
             # Copy self._new_params to model params
-            if self._e5m2_allgather:
+            if self._e5m2_allgather or self._do_not_flatten_model:
                 p_in = []
                 p_out = []
                 with torch.no_grad():

@@ -585,7 +568,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                         state['step'] += 1
                         nels = p.numel()
                         offset = self._grads_info[param_i]['param_offset']
-                        if self._e5m2_allgather:
+                        if self._e5m2_allgather or self._do_not_flatten_model:
                             p_in.append(self._new_params[offset:offset+nels].view_as(p))
                             p_out.append(p)
                         else:
@@ -596,9 +579,20 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                         fused_adam_cuda.unpack_e5m2_mt,
                         self._overflow_buf,
                         [p_in, p_out]);
-            self._new_params = None
+                elif self._do_not_flatten_model:
+                    multi_tensor_applier(
+                        amp_C.multi_tensor_scale,
+                        self._overflow_buf,
+                        [p_in, p_out],
+                        1.0);
+            if not self._e5m2_allgather and not self._do_not_flatten_model:
+                self._new_params = None

-        torch.cuda.current_stream().wait_stream(self._blk_st[0])
+        torch.cuda.current_stream().wait_stream(self._completion_st)
+        self._reductions_works = [None]*self._num_blocks
+        self._allgather_works = [None]*self._num_blocks

         return loss