OpenDAS / apex / Commits

Commit b85ff391
Authored Mar 18, 2020 by Thor Johnsen
Add option to revert step through double buffering
Parent: ffed6e80
Showing 1 changed file with 39 additions and 16 deletions (+39 / -16).
apex/contrib/optimizers/distributed_fused_adam.py @ b85ff391
@@ -45,7 +45,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
-                 dwu_num_ag_pg=0, dwu_num_blk_st=1):
+                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
@@ -64,6 +64,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._overflow_buf = torch.cuda.IntTensor([0])
+
+        # Way to revert a step
+        # 3 -> undo kernel + double buffer (debug, print norm of difference)
+        # 2 -> double buffer fp32 parameters
+        # 1 -> undo kernel
+        self._revert_method = revert_method
+        if self._revert_method > 1:
+            print("revert_method -> double buffer fp32 parameters, will consume more memory")
         self._last_step = False
         self._overlap_reductions = overlap_reductions
         self._global_scale = None
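For orientation, a hedged construction sketch showing how the new argument would be passed. It assumes an already-initialized torch.distributed process group and an apex build that provides the fused_adam_cuda extension; the model and hyperparameters are placeholders, not values taken from this commit.

# Hypothetical usage sketch -- model, lr and the choice of revert_method=2 are illustrative.
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,
    # New in this commit:
    #   revert_method=1 -> revert via the adam_undo kernel (previous behaviour, default)
    #   revert_method=2 -> double-buffer the fp32 p/m/v shards and restore them on revert
    #   revert_method=3 -> both, printing the norm of the difference
    #                      (raises RuntimeError at this commit; not implemented yet)
    revert_method=2,
)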
@@ -314,6 +322,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             self._fp32_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
             self._fp32_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
             self._fp32_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
+            if self._revert_method > 1:
+                self._fp32_backup_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
+                self._fp32_backup_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
+                self._fp32_backup_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
             self._copy_to_fp32 = True

         step = None
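The "will consume more memory" warning is easy to quantify: revert_method > 1 allocates three extra fp32 tensors of num_blocks * shard_size elements each. A back-of-the-envelope sketch with made-up sizes (not taken from the commit):

# Illustrative only -- num_blocks matches the dwu_num_blocks default, shard_size is invented.
num_blocks = 4
shard_size = 100_000_000
extra_elements = 3 * num_blocks * shard_size    # backup p, m and v shards
extra_bytes = extra_elements * 4                # fp32 = 4 bytes per element
print(f"double-buffer overhead: {extra_bytes / 2**30:.1f} GiB per rank")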
@@ -376,21 +388,32 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             beta1, beta2 = group['betas']
             if undo:
-                fused_adam_cuda.adam_undo(
-                                     self._fp32_p[group_buffer_start:group_buffer_end],
-                                     self._fp32_m[group_buffer_start:group_buffer_end],
-                                     self._fp32_v[group_buffer_start:group_buffer_end],
-                                     self._flat_grads[group_shard_start:group_shard_end],
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     combined_scale,
-                                     step+1, # FIXME: Verify this should be step+1
-                                     self.eps_mode,
-                                     bias_correction,
-                                     group['weight_decay'])
+                if self._revert_method == 1:
+                    fused_adam_cuda.adam_undo(
+                                         self._fp32_p[group_buffer_start:group_buffer_end],
+                                         self._fp32_m[group_buffer_start:group_buffer_end],
+                                         self._fp32_v[group_buffer_start:group_buffer_end],
+                                         self._flat_grads[group_shard_start:group_shard_end],
+                                         group['lr'],
+                                         beta1,
+                                         beta2,
+                                         group['eps'],
+                                         combined_scale,
+                                         step+1, # FIXME: Verify this should be step+1
+                                         self.eps_mode,
+                                         bias_correction,
+                                         group['weight_decay'])
+                elif self._revert_method == 2:
+                    self._fp32_p[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_p[group_buffer_start:group_buffer_end])
+                    self._fp32_m[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_m[group_buffer_start:group_buffer_end])
+                    self._fp32_v[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_v[group_buffer_start:group_buffer_end])
+                elif self._revert_method == 3:
+                    raise RuntimeError('revert_step debug option not implemented yet')
             else:
+                if self._revert_method > 1:
+                    self._fp32_backup_p[group_buffer_start:group_buffer_end].copy_(self._fp32_p[group_buffer_start:group_buffer_end])
+                    self._fp32_backup_m[group_buffer_start:group_buffer_end].copy_(self._fp32_m[group_buffer_start:group_buffer_end])
+                    self._fp32_backup_v[group_buffer_start:group_buffer_end].copy_(self._fp32_v[group_buffer_start:group_buffer_end])
                 fused_adam_cuda.adam(
                                      self._fp32_p[group_buffer_start:group_buffer_end],
                                      self._new_params[group_shard_start:group_shard_end],
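Stripped of the sharding details, the two revert strategies differ in how they get back to the pre-step state: revert_method=1 arithmetically inverts the Adam update with the adam_undo kernel, while revert_method=2 snapshots the fp32 state before each update and copies it back on revert. A minimal self-contained sketch of the double-buffering pattern, with made-up tensor names and a plain SGD-style update standing in for the fused Adam kernel:

# Sketch of the revert_method=2 idea; names and the update rule are placeholders.
import torch

p = torch.zeros(1024)            # stands in for the fp32 parameter shard
backup_p = torch.empty_like(p)   # the second buffer

def step(grad, lr=1e-3):
    backup_p.copy_(p)            # snapshot *before* applying the update
    p.add_(grad, alpha=-lr)      # apply the (possibly overflowed) update

def revert():
    p.copy_(backup_p)            # bit-exact restore instead of an arithmetic undo

grad = torch.randn_like(p)
before = p.clone()
step(grad)
revert()
assert torch.equal(p, before)    # the step is fully undone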
@@ -412,7 +435,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             for block in range(self._num_blocks):
                 grad_block = self._flat_grads[block*self._block_size:(block+1)*self._block_size]
                 grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
-                shard_grad_norm = grad_shards[self._rank_in_group].float().norm()
+                shard_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32, p=2)
                 partial_sum += (shard_grad_norm*shard_grad_norm)
             torch.distributed.all_reduce(partial_sum, group=self._rs_pg[0], async_op=False)
             self._L2_grad_norm = partial_sum.sqrt().item()
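The last hunk is independent of the revert feature: the per-shard L2 norm is now computed with norm(dtype=torch.float32, p=2) instead of float().norm(), presumably so the reduction accumulates in fp32 without first creating an explicit fp32 copy of the (typically fp16) gradient shard. Numerically the two forms agree; a quick check with an illustrative shard size:

# Equivalence check for the norm change; the shard size is illustrative.
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
shard = torch.randn(1 << 20, dtype=torch.float16, device=device)

old_norm = shard.float().norm()                    # cast to fp32, then reduce
new_norm = shard.norm(dtype=torch.float32, p=2)    # reduce with fp32 accumulation

print(old_norm.item(), new_norm.item())
print(torch.allclose(old_norm, new_norm))          # True up to reduction-order rounding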