OpenDAS / apex

Commit 91a5a87e, authored May 06, 2020 by Thor Johnsen

Slight improvements

parent 25c80afe
Showing 3 changed files with 32 additions and 52 deletions (+32, -52).
apex/contrib/optimizers/distributed_fused_adam.py     +9  -15
apex/contrib/optimizers/distributed_fused_adam_v2.py  +9  -15
apex/contrib/optimizers/distributed_fused_adam_v3.py  +14 -22
apex/contrib/optimizers/distributed_fused_adam.py

@@ -67,6 +67,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self.eps_mode = 0 if eps_inside_sqrt else 1
         self._overflow_buf = torch.cuda.IntTensor([0])
+        self._has_overflow = False
         assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
@@ -299,7 +300,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
@@ -311,10 +312,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             end = (self._current_block+1)*self._block_size
             flush_block = [start, end]
-        if self._current_block == 0:
-            # reset
-            self._grads_generated = [False]*len(self._grads_info)
         return flush_block
 
     def _pipeline_block_reductions(self, block_id):
@@ -351,7 +348,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             l2_grad_norm_sq = torch.empty([1], device='cuda')
             l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
             torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
-            self._L2_grad_norm = l2_grad_norm_sq.sqrt()
+            self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
 
     def __launch_step_kernel(self, p, p_copy, m, v, g):
         combined_scale = self._global_scale
@@ -448,8 +445,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Clears the overflow flag.
         """
-        has_overflow = self._overflow_buf.item()
-        self._overflow_buf.zero_()
+        has_overflow = self._has_overflow
+        self._has_overflow = False
         return has_overflow
 
     @property
@@ -457,7 +454,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Does not clear overflow flag.
         """
-        return self._overflow_buf.item()
+        return self._has_overflow
 
     def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
         """Strided check for overflow.
@@ -471,6 +468,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 out_p,
                 stride,
                 1 if clear else 0)
+        self._has_overflow = False if self._overflow_buf.item() == 0 else True
+        return self._has_overflow
 
     @property
     def L2_grad_norm(self):
@@ -542,13 +541,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         with torch.cuda.stream(self._completion_st):
             # Check for overflow
             # Store state for loss scaler calculation
-            if skip_overflow_check:
-                has_overflow = False
-            else:
-                self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-                has_overflow = self.peek_overflow
+            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
             if has_overflow:
-                print("Reverting step")
                 self.revert_step()
             else:
                 # Copy self._new_params to model params
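The common thread in this file is that the CUDA overflow buffer is now read exactly once, inside strided_check_finite, and the result is cached in the new host-side flag self._has_overflow; the has_overflow and peek_overflow properties return that cached bool instead of calling self._overflow_buf.item() (a blocking GPU-to-CPU read) on every query, and the completion-stream caller can rely on the value strided_check_finite now returns. Below is a minimal, self-contained sketch of this caching pattern; OverflowTracker and check_finite are illustrative stand-ins built on torch.isfinite, not the fused_adam_cuda kernel the optimizer actually calls.

import torch

class OverflowTracker:
    """Sketch of the overflow-flag caching pattern used in this commit."""

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        # Device-side accumulator written by (possibly asynchronous) checks.
        self._overflow_buf = torch.zeros([1], dtype=torch.int32, device=device)
        # Host-side cache, refreshed only when the buffer is explicitly read.
        self._has_overflow = False

    def check_finite(self, t, clear=True):
        # Stand-in for fused_adam_cuda.strided_check_finite: mark the buffer
        # if any element is non-finite, then read it once and cache the result.
        if clear:
            self._overflow_buf.zero_()
        self._overflow_buf |= (~torch.isfinite(t)).any().to(torch.int32)
        self._has_overflow = self._overflow_buf.item() != 0  # the single sync point
        return self._has_overflow

    @property
    def peek_overflow(self):
        # No .item() here: reading the cached bool does not touch the device.
        return self._has_overflow

    @property
    def has_overflow(self):
        # Return and clear the cached flag, mirroring the updated property.
        flag, self._has_overflow = self._has_overflow, False
        return flag

tracker = OverflowTracker()
grads = torch.tensor([1.0, float("inf")], device=tracker._overflow_buf.device)
print(tracker.check_finite(grads))   # True
print(tracker.peek_overflow)         # True, flag untouched
print(tracker.has_overflow)          # True, flag cleared
print(tracker.has_overflow)          # False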
apex/contrib/optimizers/distributed_fused_adam_v2.py

@@ -67,6 +67,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         self.eps_mode = 0 if eps_inside_sqrt else 1
         self._overflow_buf = torch.cuda.IntTensor([0])
+        self._has_overflow = False
         assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
@@ -352,7 +353,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
@@ -364,10 +365,6 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
             end = (self._current_block+1)*self._block_size
             flush_block = [start, end]
-        if self._current_block == 0:
-            # reset
-            self._grads_generated = [False]*len(self._grads_info)
         return flush_block
 
     def _pipeline_block_reductions(self, block_id):
@@ -404,7 +401,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
             l2_grad_norm_sq = torch.empty([1], device='cuda')
             l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
             torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
-            self._L2_grad_norm = l2_grad_norm_sq.sqrt()
+            self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
 
     def __launch_step_kernel(self, p, p_copy, m, v, g):
         combined_scale = self._global_scale
@@ -501,8 +498,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Clears the overflow flag.
         """
-        has_overflow = self._overflow_buf.item()
-        self._overflow_buf.zero_()
+        has_overflow = self._has_overflow
+        self._has_overflow = False
         return has_overflow
 
     @property
@@ -510,7 +507,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         """Check if overflows were detected by any call to step(...) method.
         Does not clear overflow flag.
         """
-        return self._overflow_buf.item()
+        return self._has_overflow
 
     def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
         """Strided check for overflow.
@@ -524,6 +521,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
                 out_p,
                 stride,
                 1 if clear else 0)
+        self._has_overflow = False if self._overflow_buf.item() == 0 else True
+        return self._has_overflow
 
     @property
     def L2_grad_norm(self):
@@ -595,13 +594,8 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
         with torch.cuda.stream(self._completion_st):
             # Check for overflow
             # Store state for loss scaler calculation
-            if skip_overflow_check:
-                has_overflow = False
-            else:
-                self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-                has_overflow = self.peek_overflow
+            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
             if has_overflow:
-                print("Reverting step")
                 self.revert_step()
             else:
                 # Copy self._new_params to model params
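distributed_fused_adam_v2.py receives the same set of changes as distributed_fused_adam.py, only at different line offsets. One detail worth noting in both files is the new self._current_block > 0 guard in _get_flush_block: when self._current_block is 0, the old condition indexed self._low_param_i with -1, which Python resolves to the last element rather than raising an error. A tiny standalone illustration of that indexing behaviour follows; the lists are hypothetical stand-ins for the optimizer's bookkeeping, not its real contents.

# Hypothetical stand-ins for the optimizer's per-block bookkeeping.
low_param_i = [0, 3, 7]                # lowest parameter index owned by each block
grads_generated = [True] * 8           # per-parameter "gradient ready" flags
current_block = 0

# Old condition: index -1 silently wraps to the last block's entry.
print(low_param_i[current_block - 1])  # prints 7, no IndexError is raised

# New condition: the guard short-circuits before any indexing happens.
if current_block > 0 and grads_generated[low_param_i[current_block - 1]]:
    print("flush block", current_block)
else:
    print("nothing to flush when current_block == 0")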
apex/contrib/optimizers/distributed_fused_adam_v3.py

@@ -86,7 +86,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         self._e5m2_allgather = e5m2_allgather
         self._do_not_flatten_model = do_not_flatten_model
         self._full_pipeline = full_pipeline
-        self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = None
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
         self._world_size = torch.distributed.get_world_size()
@@ -202,7 +201,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
-        if self._grads_generated[self._low_param_i[self._current_block-1]]:
+        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
@@ -214,10 +213,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
             end = (self._current_block+1)*self._block_size
             flush_block = [start, end]
-        if self._current_block == 0:
-            # reset
-            self._grads_generated = [False]*len(self._grads_info)
         return flush_block
 
     def __launch_step_kernel(self, p, p_copy, m, v, g):
@@ -267,7 +262,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         if block_id == 0:
             self._l2_grad_norm_st.wait_stream(self._dwu_st)
             with torch.cuda.stream(self._l2_grad_norm_st):
-                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2)
+                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
 
         flush_block = self._get_flush_block()
 
     def set_global_scale(self, global_scale):
@@ -303,7 +298,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
             torch.distributed.all_reduce(self._flat_grads)
             self._l2_grad_norm_st.wait_stream(self._dwu_st)
             with torch.cuda.stream(self._l2_grad_norm_st):
-                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2)
+                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
         self._current_block = self._num_blocks
         self._grads_generated = [False]*len(self._grads_info)
@@ -313,7 +308,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
-        if not self.has_overflow:
         with torch.cuda.stream(self._dwu_st):
             self.__launch_step_kernel(
                 self._fp32_p,
@@ -325,8 +319,6 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
         for p in self._model_params:
             self.state[p]['step'] += 1
         torch.cuda.current_stream().wait_stream(self._dwu_st)
-        else:
-            print("Overflow detected, skipping step")
         return loss
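As in the other two files, the gradient-norm bookkeeping here now ends in .item(), so self._L2_grad_norm holds a plain Python float instead of a CUDA tensor; .item() copies the single value back to the host and blocks until the work that produced it has finished. A minimal sketch of that conversion, using a hypothetical standalone helper rather than the optimizer's own method:

import torch

def l2_grad_norm(flat_grads: torch.Tensor) -> float:
    """Return the gradient L2 norm as a plain Python float.

    Mirrors the pattern in this commit: the norm is computed on the device,
    then .item() copies the scalar to the host (a blocking synchronization),
    so callers get a float rather than a device tensor.
    """
    return flat_grads.norm(dtype=torch.float32, p=2).item()

# Example with a stand-in flattened gradient buffer.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
flat_grads = torch.randn(1024, device=device, dtype=dtype)
norm = l2_grad_norm(flat_grads)
print(type(norm), norm)   # <class 'float'> ...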