OpenDAS / apex

Commit 91a5a87e, authored May 06, 2020 by Thor Johnsen
Parent: 25c80afe

    Slight improvements

Showing 3 changed files with 32 additions and 52 deletions:

    apex/contrib/optimizers/distributed_fused_adam.py      +9  -15
    apex/contrib/optimizers/distributed_fused_adam_v2.py   +9  -15
    apex/contrib/optimizers/distributed_fused_adam_v3.py   +14 -22
apex/contrib/optimizers/distributed_fused_adam.py

@@ -67,6 +67,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self.eps_mode = 0 if eps_inside_sqrt else 1
         self._overflow_buf = torch.cuda.IntTensor([0])
+        self._has_overflow = False
         assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
@@ -299,7 +300,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
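The added `self._current_block > 0` guard closes a negative-indexing hole: when `_current_block` is 0, `self._low_param_i[self._current_block-1]` becomes `self._low_param_i[-1]`, which Python silently resolves to the last entry instead of raising, so the flush test could consult the wrong parameter's flag. A tiny illustration (the values are made up, not from the diff):

    low_param_i = [0, 3, 7]                  # hypothetical contents
    current_block = 0
    print(low_param_i[current_block - 1])    # prints 7 -- wraps to the last entry, no error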
@@ -311,10 +312,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 end = (self._current_block+1)*self._block_size
                 flush_block = [start, end]

-        if self._current_block == 0:
-            # reset
-            self._grads_generated = [False]*len(self._grads_info)
-
         return flush_block

     def _pipeline_block_reductions(self, block_id):
@@ -351,7 +348,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 l2_grad_norm_sq = torch.empty([1], device='cuda')
                 l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                 torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
-                self._L2_grad_norm = l2_grad_norm_sq.sqrt()
+                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()

     def __launch_step_kernel(self, p, p_copy, m, v, g):
         combined_scale = self._global_scale
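The only change here is the trailing `.item()`: `l2_grad_norm_sq.sqrt()` is still a one-element CUDA tensor, whereas `.item()` copies the value back to the host and stores a plain Python float in `self._L2_grad_norm`, at the cost of synchronizing with the GPU at this point. A minimal sketch of the difference (assumes a CUDA device; the value is made up):

    import torch

    l2_grad_norm_sq = torch.tensor([16.0], device='cuda')   # stand-in for the all-reduced squared norm
    as_tensor = l2_grad_norm_sq.sqrt()                       # 1-element CUDA tensor
    as_float  = l2_grad_norm_sq.sqrt().item()                # Python float 4.0; blocks until the GPU is done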
@@ -448,8 +445,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Clears the overflow flag.
         """
-        has_overflow = self._overflow_buf.item()
-        self._overflow_buf.zero_()
+        has_overflow = self._has_overflow
+        self._has_overflow = False
         return has_overflow

     @property
@@ -457,7 +454,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Does not clear overflow flag.
         """
-        return self._overflow_buf.item()
+        return self._has_overflow

     def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
         """Strided check for overflow.
@@ -471,6 +468,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 out_p, stride, 1 if clear else 0)
+        self._has_overflow = False if self._overflow_buf.item() == 0 else True
+        return self._has_overflow

     @property
     def L2_grad_norm(self):
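Taken together, these hunks cache the overflow state on the host: `strided_check_finite` reads the device-side `_overflow_buf` once via `.item()`, stores the result in the new `_has_overflow` boolean, and now also returns it, so the `has_overflow` and `peek_overflow` properties no longer force a GPU-to-host copy on every query. A condensed sketch of the pattern, with the finiteness kernel elided and the class name hypothetical:

    import torch

    class OverflowState:                                     # hypothetical, for illustration only
        def __init__(self):
            self._overflow_buf = torch.cuda.IntTensor([0])   # written by the fused finiteness check
            self._has_overflow = False                       # host-side cache of the last check

        def strided_check_finite(self):
            # ... launch the multi-tensor check that writes _overflow_buf ...
            self._has_overflow = False if self._overflow_buf.item() == 0 else True
            return self._has_overflow                        # one device sync per check

        @property
        def peek_overflow(self):
            return self._has_overflow                        # no device sync, flag kept

        @property
        def has_overflow(self):
            had, self._has_overflow = self._has_overflow, False
            return had                                       # no device sync, flag cleared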
@@ -542,13 +541,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         with torch.cuda.stream(self._completion_st):
             # Check for overflow
             # Store state for loss scaler calculation
-            if skip_overflow_check:
-                has_overflow = False
-            else:
-                self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-                has_overflow = self.peek_overflow
+            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
             if has_overflow:
                 print("Reverting step")
                 self.revert_step()
             else:
                 # Copy self._new_params to model params
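The collapsed check leans on the return value that `strided_check_finite` gained above; behaviour is otherwise unchanged, and the separate `peek_overflow` read goes away. Roughly, as a comment sketch:

    # old: call strided_check_finite(...) for its side effect, then read self.peek_overflow
    # new: has_overflow = False if skip_overflow_check else self.strided_check_finite(...)
    #      (the call itself now yields the flag)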
apex/contrib/optimizers/distributed_fused_adam_v2.py

@@ -67,6 +67,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         self.eps_mode = 0 if eps_inside_sqrt else 1
         self._overflow_buf = torch.cuda.IntTensor([0])
+        self._has_overflow = False
         assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
@@ -352,7 +353,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
@@ -364,10 +365,6 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
                 end = (self._current_block+1)*self._block_size
                 flush_block = [start, end]

-        if self._current_block == 0:
-            # reset
-            self._grads_generated = [False]*len(self._grads_info)
-
         return flush_block

     def _pipeline_block_reductions(self, block_id):
@@ -404,7 +401,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
                 l2_grad_norm_sq = torch.empty([1], device='cuda')
                 l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                 torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
-                self._L2_grad_norm = l2_grad_norm_sq.sqrt()
+                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()

     def __launch_step_kernel(self, p, p_copy, m, v, g):
         combined_scale = self._global_scale
@@ -501,8 +498,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Clears the overflow flag.
         """
-        has_overflow = self._overflow_buf.item()
-        self._overflow_buf.zero_()
+        has_overflow = self._has_overflow
+        self._has_overflow = False
         return has_overflow

     @property
@@ -510,7 +507,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Does not clear overflow flag.
         """
-        return self._overflow_buf.item()
+        return self._has_overflow

     def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
         """Strided check for overflow.
@@ -524,6 +521,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
                 out_p, stride, 1 if clear else 0)
+        self._has_overflow = False if self._overflow_buf.item() == 0 else True
+        return self._has_overflow

     @property
     def L2_grad_norm(self):
@@ -595,13 +594,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         with torch.cuda.stream(self._completion_st):
             # Check for overflow
             # Store state for loss scaler calculation
-            if skip_overflow_check:
-                has_overflow = False
-            else:
-                self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-                has_overflow = self.peek_overflow
+            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
             if has_overflow:
                 print("Reverting step")
                 self.revert_step()
             else:
                 # Copy self._new_params to model params
apex/contrib/optimizers/distributed_fused_adam_v3.py

@@ -86,7 +86,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         self._e5m2_allgather = e5m2_allgather
         self._do_not_flatten_model = do_not_flatten_model
         self._full_pipeline = full_pipeline
-        self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = None
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
         self._world_size = torch.distributed.get_world_size()
@@ -202,7 +201,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
@@ -214,10 +213,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
                 end = (self._current_block+1)*self._block_size
                 flush_block = [start, end]

-        if self._current_block == 0:
-            # reset
-            self._grads_generated = [False]*len(self._grads_info)
-
         return flush_block

     def __launch_step_kernel(self, p, p_copy, m, v, g):
@@ -267,7 +262,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         if block_id == 0:
             self._l2_grad_norm_st.wait_stream(self._dwu_st)
             with torch.cuda.stream(self._l2_grad_norm_st):
-                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2)
+                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()

         flush_block = self._get_flush_block()

     def set_global_scale(self, global_scale):
@@ -303,7 +298,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
             torch.distributed.all_reduce(self._flat_grads)
             self._l2_grad_norm_st.wait_stream(self._dwu_st)
             with torch.cuda.stream(self._l2_grad_norm_st):
-                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2)
+                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
             self._current_block = self._num_blocks
             self._grads_generated = [False]*len(self._grads_info)
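In V3 the norm is computed and converted on the dedicated `_l2_grad_norm_st` stream, so the added `.item()` makes the host wait for that stream's work before `_L2_grad_norm` is assigned, and the stored value is a plain float rather than a CUDA tensor. A minimal sketch of the stream/`.item()` interaction (assumes CUDA; shapes and values are made up):

    import torch

    side_stream = torch.cuda.Stream()                        # stand-in for _l2_grad_norm_st
    flat_grads = torch.randn(1 << 20, device='cuda', dtype=torch.float16)

    side_stream.wait_stream(torch.cuda.current_stream())     # mirrors _l2_grad_norm_st.wait_stream(...)
    with torch.cuda.stream(side_stream):
        # .item() blocks the host until the norm kernel queued on side_stream finishes,
        # then returns a Python float.
        l2_grad_norm = flat_grads.norm(dtype=torch.float32, p=2).item()

    print(type(l2_grad_norm))                                # <class 'float'>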
@@ -313,20 +308,17 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()

-        if not self.has_overflow:
-            with torch.cuda.stream(self._dwu_st):
-                self.__launch_step_kernel(self._fp32_p, self._flat_params_shards[self._rank_in_group], self._fp32_m, self._fp32_v, self._flat_grads_shards[self._rank_in_group])
-                torch.distributed.all_gather(self._flat_params_shards, self._flat_params_shards[self._rank_in_group], group=self._ag_pg, no_copy=True)
-                for p in self._model_params: self.state[p]['step'] += 1
-            torch.cuda.current_stream().wait_stream(self._dwu_st)
-        else:
-            print("Overflow detected, skipping step")
+        with torch.cuda.stream(self._dwu_st):
+            self.__launch_step_kernel(self._fp32_p, self._flat_params_shards[self._rank_in_group], self._fp32_m, self._fp32_v, self._flat_grads_shards[self._rank_in_group])
+            torch.distributed.all_gather(self._flat_params_shards, self._flat_params_shards[self._rank_in_group], group=self._ag_pg, no_copy=True)
+            for p in self._model_params: self.state[p]['step'] += 1
+        torch.cuda.current_stream().wait_stream(self._dwu_st)

         return loss
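With this hunk, V3's `step()` no longer consults `self.has_overflow`; it always launches the fused update kernel and the parameter all-gather. If overflow-skipping is still wanted, it has to happen in the training loop. A hedged sketch of one way a caller might do that, assuming the optimizer exposes the `L2_grad_norm` property seen in the other variants; the `loss_scaler` object, its `scale` attribute, and its `update` method are illustrative assumptions, not part of this commit:

    import math

    def training_step(optimizer, loss_scaler, loss):
        (loss * loss_scaler.scale).backward()
        grad_norm = optimizer.L2_grad_norm           # now a plain float (or None before the first reduction)
        if grad_norm is not None and not math.isfinite(grad_norm):
            loss_scaler.update(overflow=True)        # hypothetical API: shrink the scale, skip the step
        else:
            optimizer.step()
            loss_scaler.update(overflow=False)       # hypothetical API: grow the scale as usual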