OpenDAS / apex
Commit 8ed8eaac authored May 30, 2020 by Thor Johnsen

Use correct names for mt lamb cuda kernels

Parent: 45388d48
Showing 1 changed file with 4 additions and 27 deletions.
apex/contrib/optimizers/distributed_fused_lamb.py  (+4, -27)
...
@@ -72,8 +72,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
                 e5m2_allgather=False):
        global fused_adam_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
        # FIXME: Import multi_tensor_lamb_* kernels instead
        distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")

        self._amp_scale_adjustment = amp_scale_adjustment
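For readers unfamiliar with the apex contrib layout, the sketch below (not part of this commit) shows the lazy-import pattern used above; the helper name and the fallback behaviour are illustrative assumptions.

import importlib

def load_optional_extension(name):
    # Hypothetical helper: load a compiled apex extension such as
    # "fused_adam_cuda" or "distributed_lamb_cuda", returning None
    # when apex was built without that extension.
    try:
        return importlib.import_module(name)
    except ImportError:
        return None

fused_adam_cuda = load_optional_extension("fused_adam_cuda")
distributed_lamb_cuda = load_optional_extension("distributed_lamb_cuda")
if distributed_lamb_cuda is None:
    print("distributed_lamb_cuda not found; DistributedFusedLAMB needs apex built with its contrib extensions")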
...
@@ -90,13 +89,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
        self._overflow_buf = torch.cuda.IntTensor([0])
        self._has_overflow = False
        self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
        self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights
        import amp_C
        self.multi_tensor_lamb_compute_update_term = amp_C.multi_tensor_distopt_lamb_
        import amp_C
        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
        # Skip buffer
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
        self.multi_tensor_lamb = amp_C.multi_tensor_lamb
        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm

        self._last_step = False
        self._overlap_reductions = overlap_reductions
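As a usage note (not part of the diff): the multi_tensor_l2norm kernel bound here is normally driven through apex's multi_tensor_applier, as in the __compute_contrib_param_norm hunk further down. A minimal sketch, assuming apex was installed with --cuda_ext and a CUDA device is available:

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# A list of tensors whose combined L2 norm we want (stand-ins for parameters).
params = [torch.randn(1024, device='cuda') for _ in range(4)]
overflow_buf = torch.cuda.IntTensor([0])  # dummy no-op flag required by the kernel interface

# Returns (total_norm, per_tensor_norms); the trailing True asks for per-tensor norms as well.
total_norm, per_tensor_norms = multi_tensor_applier(
    amp_C.multi_tensor_l2norm, overflow_buf, [params], True)
print(total_norm.item(), per_tensor_norms.shape)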
...
@@ -423,25 +419,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
            torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
            self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()

    def __launch_step_kernel(self, p, p_copy, m, v, g):
        combined_scale = self._global_scale
        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
            combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.reversible_adam(p, p_copy, m, v, g,
                                        self._param_group['lr'],
                                        beta1,
                                        beta2,
                                        self._param_group['eps'],
                                        combined_scale,
                                        self._param_state['step']+1,
                                        self.eps_mode,
                                        bias_correction,
                                        self._param_group['weight_decay'])

    def __compute_contrib_param_norm(self):
        if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
            gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_model_param_for_norm_fp16], True)
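The deleted __launch_step_kernel combined loss scaling and gradient clipping into the single scale passed to reversible_adam. The plain-Python sketch below reproduces that arithmetic only; the function name and the example numbers are illustrative, not from the source.

import math

def combined_scale(global_scale, max_grad_norm, l2_grad_norm):
    # Mirrors the arithmetic of the removed __launch_step_kernel:
    # l2_grad_norm is the norm of the scaled gradients, so dividing by
    # global_scale recovers the unscaled norm; when that exceeds
    # max_grad_norm, the divisor handed to the fused kernel grows so the
    # effective update is clipped.
    scale = global_scale
    if max_grad_norm > 0 and math.isfinite(l2_grad_norm):
        clip = max_grad_norm / (l2_grad_norm / global_scale + 1e-6)
        scale = global_scale / min(1, clip)
    return scale

# Illustrative numbers: loss scale 65536, clip threshold 1.0, scaled grad
# norm 131072 (unscaled norm ~2.0) -> gradients get divided by ~2x the scale.
print(combined_scale(65536.0, 1.0, 131072.0))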