OpenDAS / apex · Commits

Commit 8ed8eaac
Authored May 30, 2020 by Thor Johnsen

Use correct names for mt lamb cuda kernels

Parent: 45388d48
Showing 1 changed file with 4 additions and 27 deletions.
apex/contrib/optimizers/distributed_fused_lamb.py
@@ -72,8 +72,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
                  e5m2_allgather=False):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
         distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")
-        # FIXME: Import multi_tensor_lamb_* kernels instead
         self._amp_scale_adjustment = amp_scale_adjustment
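
For reference, the two importlib calls above resolve apex's prebuilt CUDA extension modules when the optimizer is constructed. A minimal sketch of that loading pattern follows; the helper name and the error message are illustrative additions (not apex code), and it assumes apex was installed with the corresponding CUDA/contrib extensions compiled.

    import importlib

    def _load_cuda_extension(name):
        # fused_adam_cuda and distributed_lamb_cuda exist only when apex was
        # built with its CUDA/contrib extensions, so surface a clearer error
        # than a bare ImportError.
        try:
            return importlib.import_module(name)
        except ImportError as err:
            raise RuntimeError(
                "%s is not available; rebuild apex with its CUDA extensions enabled" % name
            ) from err

    fused_adam_cuda = _load_cuda_extension("fused_adam_cuda")
    distributed_lamb_cuda = _load_cuda_extension("distributed_lamb_cuda")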
@@ -90,13 +89,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._overflow_buf = torch.cuda.IntTensor([0])
         self._has_overflow = False
-        import amp_C
-        self.multi_tensor_lamb_compute_update_term = amp_C.multi_tensor_distopt_lamb_
-        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
-        import amp_C
-        # Skip buffer
-        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
-        self.multi_tensor_lamb = amp_C.multi_tensor_lamb
+        self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
+        self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights
+        import amp_C
+        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
         self._last_step = False
         self._overlap_reductions = overlap_reductions
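
The rebinding above keeps amp_C.multi_tensor_l2norm for norm computation while taking the LAMB update kernels from the distributed_lamb_cuda extension. Below is a hedged usage sketch of the l2norm kernel, mirroring the multi_tensor_applier call visible in __compute_contrib_param_norm in the last hunk; it assumes apex built with its CUDA extensions and tensors resident on a GPU, and the parameter list is made up for illustration.

    import torch
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier

    # Stand-in parameter shards; in the optimizer these are the flattened
    # fp16/fp32 parameter fragments owned by this rank.
    params = [torch.randn(1024, device="cuda") for _ in range(4)]
    overflow_buf = torch.cuda.IntTensor([0])  # kernels set this flag on inf/nan

    # per_tensor=True returns the overall L2 norm plus one norm per input tensor.
    total_norm, per_tensor_norms = multi_tensor_applier(
        amp_C.multi_tensor_l2norm, overflow_buf, [params], True)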
@@ -423,25 +419,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
         self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
 
-    def __launch_step_kernel(self, p, p_copy, m, v, g):
-        combined_scale = self._global_scale
-        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
-            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
-            combined_scale = self._global_scale / min(1, combined_scale)
-        bias_correction = 1 if self._param_group['bias_correction'] else 0
-        beta1, beta2 = self._param_group['betas']
-        fused_adam_cuda.reversible_adam(
-            p, p_copy, m, v, g,
-            self._param_group['lr'],
-            beta1,
-            beta2,
-            self._param_group['eps'],
-            combined_scale,
-            self._param_state['step']+1,
-            self.eps_mode,
-            bias_correction,
-            self._param_group['weight_decay'])
-
     def __compute_contrib_param_norm(self):
         if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
             gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_model_param_for_norm_fp16], True)
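
The deleted __launch_step_kernel folded the dynamic loss scale and gradient clipping into a single combined_scale before calling the reversible Adam kernel. A plain-Python restatement of that removed arithmetic, for reference only (the replacement path goes through the distributed_lamb_cuda kernels bound above):

    import math

    def combined_scale(global_scale, l2_grad_norm, max_grad_norm):
        # Gradients are held pre-multiplied by global_scale, so the true norm
        # is l2_grad_norm / global_scale. When it exceeds max_grad_norm, the
        # clipping factor is folded back into the scale handed to the kernel.
        scale = global_scale
        if max_grad_norm > 0 and math.isfinite(l2_grad_norm):
            clip = max_grad_norm / (l2_grad_norm / global_scale + 1e-6)
            scale = global_scale / min(1, clip)
        return scale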