OpenDAS / apex · Commits

Commit feb93a2a, authored Apr 02, 2020 by Kexin Yu
Parent: 8e5699e4

    check empty lists
Showing 1 changed file with 9 additions and 14 deletions.

apex/contrib/optimizers/fused_lamb.py  (+9, -14)
@@ -83,7 +83,6 @@ class FusedLAMB(torch.optim.Optimizer):
         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
-        print("apex.contrib.optimiziers.FusedLAMB: testing global gradient clipping")

     def zero_grad(self):
         if self.set_grad_none:
@@ -117,24 +116,20 @@ class FusedLAMB(torch.optim.Optimizer):
                 else:
                     raise RuntimeError('FusedLAMB only support fp16 and fp32.')

-        print("====after collect")
-        print("====g_all_32:", g_all_32)
-        print("====g_all_16:", g_all_16)
-
+        g_norm_32, g_norm_16 = 0.0, 0.0
         # compute grad norm for two lists
-        g_norm_32, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,
-                                                          self._dummy_overflow_buf,
-                                                          [g_all_32], True)
-        g_norm_16, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,
-                                                          self._dummy_overflow_buf,
-                                                          [g_all_16], True)
-        print("====after multi_tensor_l2norm")
+        if len(g_all_32) > 0:
+            g_norm_32, _ = multi_tensor_applier(self.multi_tensor_l2norm,
+                                                self._dummy_overflow_buf,
+                                                [g_all_32], False)
+        if len(g_all_16) > 0:
+            g_norm_16, _ = multi_tensor_applier(self.multi_tensor_l2norm,
+                                                self._dummy_overflow_buf,
+                                                [g_all_16], False)

         # blend two grad norms to get global grad norm
         global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
         max_grad_norm = self.defaults['max_grad_norm']
-        print("====global_grad_norm:", global_grad_norm)
-        print("====max_grad_norm:", max_grad_norm)

         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
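The commit removes the temporary debug prints and guards each fused L2-norm call so it is skipped when the corresponding gradient list is empty, which can happen when a model has only fp32 or only fp16 parameters. The sketch below is a minimal illustration of that control flow, not the apex implementation: plain torch.norm stands in for the fused multi_tensor_l2norm kernel, and the helper name global_grad_norm is invented here. Each partial norm defaults to 0.0, an empty list contributes nothing, and the fp32 and fp16 norms are blended as sqrt(g32^2 + g16^2), which equals the L2 norm of all gradients taken together.

# Minimal sketch (assumed names, plain PyTorch) of the guarded global-norm pattern
# introduced by this diff; torch.norm replaces apex's fused multi_tensor_l2norm.
import math
import torch

def global_grad_norm(g_all_32, g_all_16):
    # Default both partial norms to 0.0 so an empty list contributes nothing.
    g_norm_32, g_norm_16 = 0.0, 0.0
    if len(g_all_32) > 0:
        g_norm_32 = torch.norm(torch.cat([g.flatten() for g in g_all_32])).item()
    if len(g_all_16) > 0:
        g_norm_16 = torch.norm(torch.cat([g.float().flatten() for g in g_all_16])).item()
    # Blend the per-dtype norms: sqrt(||g32||^2 + ||g16||^2) is the L2 norm of the
    # concatenation of both gradient lists.
    return math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)

# Example: all parameters are fp32, so the fp16 list is empty and its norm
# computation is skipped entirely instead of being called on an empty list.
print(global_grad_norm([torch.randn(3, 3), torch.randn(5)], []))

With the guards in place, the norm kernel is never invoked with an empty tensor list, and the blended global_grad_norm is still compared against max_grad_norm exactly as before.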