Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
47e3367f
Commit
47e3367f
authored
Jun 11, 2019
by
Michael Carilli
Browse files
Allow multi_tensor_lamb to update fp16 params
parent
04667139
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
csrc/multi_tensor_lamb_stage_1.cu
csrc/multi_tensor_lamb_stage_1.cu
+3
-3
No files found.
csrc/multi_tensor_lamb_stage_1.cu
View file @
47e3367f
...
...
@@ -100,20 +100,20 @@ void multi_tensor_lamb_stage1_cuda(
float
beta1_correction
=
1.0
f
-
std
::
pow
(
beta1
,
next_step
);
float
beta2_correction
=
1.0
f
-
std
::
pow
(
beta2
,
next_step
);
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
0
][
0
].
scalar_type
(),
0
,
"lamb_stage_1"
,
using
accscalar_t_0
=
acc_type
<
scalar_t_0
,
true
>
;
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
1
][
0
].
scalar_type
(),
1
,
"lamb_stage_1"
,
multi_tensor_apply
<
5
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
LAMBStage1Functor
<
scalar_t_0
,
acc
scalar_t_
0
>
(),
LAMBStage1Functor
<
scalar_t_0
,
scalar_t_
1
>
(),
per_tensor_decay
.
data
<
float
>
(),
beta1
,
beta2
,
beta1_correction
,
beta2_correction
,
epsilon
,
clipped_global_grad_norm
);
)
clipped_global_grad_norm
);
)
)
AT_CUDA_CHECK
(
cudaGetLastError
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment