OpenDAS / apex

Commit 32d2c4e2, authored Mar 31, 2020 by Kexin Yu

clip gradients globally, rather than per group

Parent: 8405d436
Showing 3 changed files with 32 additions and 6 deletions:

  apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp         +1  -0
  apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu   +2  -4
  apex/contrib/optimizers/fused_lamb.py                    +29 -2
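In short, clipping now uses one gradient norm computed over every parameter in every group, instead of one norm per parameter group: step() gathers all fp32 and fp16 gradients, takes the L2 norm of each list, and blends the two into a single global norm that is handed to the fused CUDA kernel. Below is a minimal standalone sketch of that blend; the random tensors are stand-ins, and plain PyTorch ops replace the fused multi_tensor_l2norm path used in the actual diff.

import math
import torch

# Per-dtype gradient lists, as FusedLAMB.step() builds them (stand-in data).
g_all_32 = [torch.randn(10), torch.randn(5)]   # fp32 gradients
g_all_16 = [torch.randn(8)]                    # fp16 gradients (kept fp32 here for simplicity)

def list_l2_norm(tensors):
    # L2 norm over all elements of all tensors in the list.
    return math.sqrt(sum(float(t.pow(2).sum()) for t in tensors))

g_norm_32 = list_l2_norm(g_all_32)
g_norm_16 = list_l2_norm(g_all_16)

# Blend the two per-dtype norms into one global norm, the same relation the
# new Python code applies: sqrt(n32^2 + n16^2).
global_grad_norm = math.sqrt(g_norm_32 ** 2 + g_norm_16 ** 2)
print(global_grad_norm)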
apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp

@@ -13,6 +13,7 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
+  const float global_grad_norm,
   const float max_grad_norm);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 ...
apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu

@@ -227,6 +227,7 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
+  const float global_grad_norm,
   const float max_grad_norm)
 {
   using namespace at;
 ...
@@ -247,9 +248,6 @@ void multi_tensor_lamb_cuda(
   std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
   std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
 
-  // Compute global grad norm
-  auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);
-
   // Compute per tensor param norm
   auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
 ...
@@ -271,7 +269,7 @@ void multi_tensor_lamb_cuda(
         epsilon,
         (adamMode_t) mode,
         weight_decay,
-        std::get<0>(grad_norm_tuple).DATA_PTR<float>(),
+        global_grad_norm,
         max_grad_norm); )
 
     // Compute update norms
 ...
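With the norm now computed once on the Python side and passed in as global_grad_norm, the kernel only has to turn that scalar into a clipping factor shared by every gradient, instead of recomputing a norm from its own grad_list on every invocation. The snippet below is a plain-Python illustration of that idea, not the CUDA functor itself; the exact rule (scale by global_grad_norm / max_grad_norm only when the norm exceeds the threshold) is an assumption about how the fused LAMB stage uses these two arguments.

import torch

def clip_by_global_norm(grads, global_grad_norm, max_grad_norm):
    # Illustration only: one clipping coefficient derived from the global norm,
    # applied uniformly to every gradient tensor (in place).
    clip_coef = global_grad_norm / max_grad_norm
    if clip_coef > 1.0:
        for g in grads:
            g.div_(clip_coef)
    return grads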
apex/contrib/optimizers/fused_lamb.py

 import torch
 import importlib
+import math
 from apex.multi_tensor_apply import multi_tensor_applier
 
 class FusedLAMB(torch.optim.Optimizer):
 ...
@@ -100,6 +101,30 @@ class FusedLAMB(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
 
+        # create separate grad lists for fp32 and fp16 params
+        g_all_32, g_all_16 = [], []
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is not None:
+                    if p.dtype == torch.float32:
+                        g_all_32.append(p.grad.data)
+                    elif p.dtype == torch.float16:
+                        g_all_16.append(p.grad.data)
+                    else:
+                        raise RuntimeError('FusedLAMB only supports fp16 and fp32.')
+
+        # compute grad norm for two lists
+        g_norm_32, _ = multi_tensor_applier(self.multi_tensor_l2norm,
+                                            self._dummy_overflow_buf,
+                                            [g_all_32], False)
+        g_norm_16, _ = multi_tensor_applier(self.multi_tensor_l2norm,
+                                            self._dummy_overflow_buf,
+                                            [g_all_16], False)
+
+        # blend two grad norms to get global grad norm
+        global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
+        max_grad_norm = self.defaults['max_grad_norm']
+
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
             beta1, beta2 = group['betas']
 ...
@@ -156,7 +181,8 @@ class FusedLAMB(torch.optim.Optimizer):
                                      group['weight_decay'],
                                      grad_averaging,
                                      self.adam_w_mode,
-                                     group['max_grad_norm'])
+                                     global_grad_norm,
+                                     max_grad_norm)
             if(len(g_32) > 0):
                 multi_tensor_applier(self.multi_tensor_lamb,
                                      self._dummy_overflow_buf,
 ...
@@ -170,6 +196,7 @@ class FusedLAMB(torch.optim.Optimizer):
                                      group['weight_decay'],
                                      grad_averaging,
                                      self.adam_w_mode,
-                                     group['max_grad_norm'])
+                                     global_grad_norm,
+                                     max_grad_norm)
 
         return loss
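End to end, the only user-visible knob is still max_grad_norm; after this commit it bounds the global gradient norm across all parameter groups rather than each group's own norm. A minimal usage sketch follows, assuming apex was built with its contrib CUDA extensions and that the constructor exposes max_grad_norm (the defaults['max_grad_norm'] lookup above implies it does):

import torch
from apex.contrib.optimizers.fused_lamb import FusedLAMB

model = torch.nn.Linear(1024, 1024).cuda()
# max_grad_norm now caps the global L2 norm over all parameter groups.
optimizer = FusedLAMB(model.parameters(), lr=1e-3, max_grad_norm=1.0)

x = torch.randn(32, 1024, device='cuda')
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()        # global grad norm is computed and clipped inside step()
optimizer.zero_grad()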