OpenDAS / apex / Commits

Commit 32d2c4e2, authored Mar 31, 2020 by Kexin Yu

    clip gradients globally, rather than per group

Parent: 8405d436
Showing 3 changed files with 32 additions and 6 deletions:

    apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp          +1   -0
    apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu    +2   -4
    apex/contrib/optimizers/fused_lamb.py                     +29  -2
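The commit message states the intent: compute one gradient norm across all parameter groups and scale every gradient by the same factor, rather than clipping each group against its own norm. A rough sketch of the two behaviours in plain PyTorch (this stands in for apex's fused path; `optimizer` is assumed to be any torch.optim.Optimizer with populated gradients):

```python
import torch

# Old behaviour (per group): each group is clipped against its own norm, so a
# group with a small norm is left untouched even when the overall gradient is large.
for group in optimizer.param_groups:
    torch.nn.utils.clip_grad_norm_(group['params'], max_norm=1.0)

# New behaviour (global): one norm over every parameter, one shared scale factor.
all_params = [p for group in optimizer.param_groups for p in group['params']]
torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
```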
apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp (view file @ 32d2c4e2)

@@ -13,6 +13,7 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
+  const float global_grad_norm,
   const float max_grad_norm);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu (view file @ 32d2c4e2)

@@ -227,6 +227,7 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
+  const float global_grad_norm,
   const float max_grad_norm)
 {
   using namespace at;
@@ -247,9 +248,6 @@ void multi_tensor_lamb_cuda(
   std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
   std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);

-  // Compute global grad norm
-  auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);
-
   // Compute per tensor param norm
   auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
@@ -271,7 +269,7 @@ void multi_tensor_lamb_cuda(
       epsilon,
       (adamMode_t) mode,
       weight_decay,
-      std::get<0>(grad_norm_tuple).DATA_PTR<float>(),
+      global_grad_norm,
       max_grad_norm);

   // Compute update norms
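With this change the kernel receives the global norm as a host-side scalar instead of reading it from a device tensor produced by multi_tensor_l2norm_cuda. The hunks do not show how the scalar is consumed downstream; a plausible reading of the clipping arithmetic implied by the two scalars, with hypothetical names, is:

```python
# Hypothetical sketch of the clipping factor implied by global_grad_norm and
# max_grad_norm; this is not apex's kernel code, only the arithmetic it suggests.
def clip_factor(global_grad_norm: float, max_grad_norm: float) -> float:
    # When the global norm exceeds the threshold, every gradient element in
    # every group is divided by the same factor; otherwise grads pass through.
    if max_grad_norm > 0.0 and global_grad_norm > max_grad_norm:
        return global_grad_norm / max_grad_norm
    return 1.0
```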
apex/contrib/optimizers/fused_lamb.py (view file @ 32d2c4e2)

 import torch
 import importlib
+import math
 from apex.multi_tensor_apply import multi_tensor_applier

 class FusedLAMB(torch.optim.Optimizer):
@@ -100,6 +101,30 @@ class FusedLAMB(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()

+        # create separate grad lists for fp32 and fp16 params
+        g_all_32, g_all_16 = [], []
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is not None:
+                    if p.dtype == torch.float32:
+                        g_all_32.append(p.grad.data)
+                    elif p.dtype == torch.float16:
+                        g_all_16.append(p.grad.data)
+                    else:
+                        raise RuntimeError('FusedLAMB only supports fp16 and fp32.')
+
+        # compute grad norm for two lists
+        g_norm_32, _ = multi_tensor_applier(self.multi_tensor_l2norm,
+                                            self._dummy_overflow_buf,
+                                            [g_all_32], False)
+        g_norm_16, _ = multi_tensor_applier(self.multi_tensor_l2norm,
+                                            self._dummy_overflow_buf,
+                                            [g_all_16], False)
+
+        # blend two grad norms to get global grad norm
+        global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
+        max_grad_norm = self.defaults['max_grad_norm']
+
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
             beta1, beta2 = group['betas']
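The blend step relies on the fact that L2 norms compose: the norm of the concatenation of two disjoint gradient sets equals the square root of the sum of their squared norms, which is why the fp32 and fp16 norms can be computed separately and merged. A quick self-contained check in plain PyTorch (not apex's fused path):

```python
import torch

# ||cat(a, b)||_2 == sqrt(||a||_2**2 + ||b||_2**2)
a = torch.randn(1000)  # stands in for all fp32 grads, flattened
b = torch.randn(500)   # stands in for all fp16 grads, flattened
blended = (a.norm() ** 2 + b.norm() ** 2).sqrt()
direct = torch.cat([a, b]).norm()
assert torch.allclose(blended, direct)
```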
@@ -156,7 +181,8 @@ class FusedLAMB(torch.optim.Optimizer):
                                      group['weight_decay'],
                                      grad_averaging,
                                      self.adam_w_mode,
-                                     group['max_grad_norm'])
+                                     global_grad_norm,
+                                     max_grad_norm)
             if (len(g_32) > 0):
                 multi_tensor_applier(self.multi_tensor_lamb,
                                      self._dummy_overflow_buf,
@@ -170,6 +196,7 @@ class FusedLAMB(torch.optim.Optimizer):
                                      group['weight_decay'],
                                      grad_averaging,
                                      self.adam_w_mode,
-                                     group['max_grad_norm'])
+                                     global_grad_norm,
+                                     max_grad_norm)

         return loss
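For reference, a minimal usage sketch. The constructor signature is not shown in this diff; max_grad_norm is assumed to be a constructor argument because the new step() code reads it from self.defaults:

```python
import torch
from apex.contrib.optimizers.fused_lamb import FusedLAMB

model = torch.nn.Linear(1024, 1024).cuda()
# max_grad_norm is assumed to be a constructor default (step() above reads
# self.defaults['max_grad_norm']); after this commit it acts as one global
# threshold rather than a per-group one.
optimizer = FusedLAMB(model.parameters(), lr=1e-3, max_grad_norm=1.0)

loss = model(torch.randn(16, 1024, device='cuda')).sum()
loss.backward()
optimizer.step()  # computes the global grad norm, then clips inside the fused kernel
```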