OpenDAS / apex · Commit 3aeea0d8

Add support for fp16 update term (new UPD_T typename in template)

Authored Jun 28, 2019 by Thor Johnsen
Parent 18f2eaee

Showing 2 changed files with 24 additions and 22 deletions
csrc/multi_tensor_lamb_stage_1.cu  +17 -16
csrc/multi_tensor_lamb_stage_2.cu  +7 -6
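The commit threads a third template parameter, UPD_T, through both LAMB functors so the intermediate 'update' tensor can be stored in fp16 while parameters and optimizer state keep their original (typically fp32) type. A minimal sketch of the pattern, with hypothetical names and none of the multi_tensor_apply chunking the real kernels use:

#include <cuda_fp16.h>

// Sketch only: the compute type T and the storage type UPD_T of the update
// buffer are independent template parameters, so UPD_T can be __half (fp16)
// while the math stays in T (e.g. float). Narrowing happens only at the
// store, mirroring the (UPD_T) cast this commit adds in LAMBStage1Functor.
template <typename T, typename UPD_T>
__global__ void write_update_sketch(UPD_T* update, const T* src, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    update[i] = (UPD_T)src[i];
}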
csrc/multi_tensor_lamb_stage_1.cu
@@ -14,7 +14,7 @@
 #define ILP 4

 // Step 1 computes the 'update' value of regular Adam optimizer.
-template<typename GRAD_T, typename T>
+template<typename GRAD_T, typename T, typename UPD_T>
 struct LAMBStage1Functor
 {
   __device__ __forceinline__ void operator()(
@@ -52,7 +52,7 @@ struct LAMBStage1Functor
     T* v = (T*)tl.addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;

-    T* update = (T*)tl.addresses[4][tensor_loc];
+    UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc];
     update += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
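The cast site exists because multi_tensor_apply hands each functor a table of untyped pointers and leaves the reinterpretation to the functor; slot 4 of the stage-1 list holds the update tensors. A rough sketch of that hand-off, with a hypothetical metadata struct standing in for apex's real tensor-list metadata:

// Hypothetical stand-in for the metadata multi_tensor_apply passes in; the
// real struct and its dimensions are defined elsewhere in apex's csrc.
struct TensorListSketch
{
  void* addresses[5][36];
};

// Each functor picks the pointer type itself, which is why changing the
// declared type of 'update' from T* to UPD_T* is all the kernel body needs
// besides the cast at the final store.
template <typename T, typename UPD_T>
__device__ void load_stage1_pointers(const TensorListSketch& tl, int tensor_loc,
                                     T*& v, UPD_T*& update)
{
  v      = (T*)tl.addresses[3][tensor_loc];
  update = (UPD_T*)tl.addresses[4][tensor_loc];
}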
@@ -100,7 +100,7 @@ struct LAMBStage1Functor
         int i = i_start + threadIdx.x + ii*blockDim.x;
         if(i < n && i < chunk_size)
         {
-          update[i] = r_p[ii];
+          update[i] = (UPD_T)r_p[ii];
           m[i] = r_m[ii];
           v[i] = r_v[ii];
         }
@@ -129,19 +129,20 @@ void multi_tensor_lamb_stage1_cuda(
   float beta2_correction = 1.0f - std::pow(beta2, next_step);

   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
     DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
-      multi_tensor_apply<5>(
-        BLOCK_SIZE,
-        chunk_size,
-        noop_flag,
-        tensor_lists,
-        LAMBStage1Functor<scalar_t_0, scalar_t_1>(),
-        per_tensor_decay.data<float>(),
-        beta1,
-        beta2,
-        beta1_correction,
-        beta2_correction,
-        epsilon,
-        clipped_global_grad_norm); ))
+      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
+        multi_tensor_apply<5>(
+          BLOCK_SIZE,
+          chunk_size,
+          noop_flag,
+          tensor_lists,
+          LAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
+          per_tensor_decay.data<float>(),
+          beta1,
+          beta2,
+          beta1_correction,
+          beta2_correction,
+          epsilon,
+          clipped_global_grad_norm); )))

   AT_CUDA_CHECK(cudaGetLastError());
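The host launcher now dispatches on a third runtime dtype, taken from tensor_lists[4][0] (the update list), and forwards it as scalar_t_2. DISPATCH_FLOAT_AND_HALF is apex's type-dispatch macro; the sketch below shows the usual shape of such a macro rather than its exact text (and assumes ATen headers are available), just to make clear where scalar_t_0/1/2 come from:

// Sketch of the dispatch pattern (not the verbatim apex macro). Each level
// binds a type alias scalar_t_<LEVEL> to float or at::Half depending on the
// runtime dtype, so the innermost body sees all three aliases at once.
#define DISPATCH_FLOAT_AND_HALF_SKETCH(TYPE, LEVEL, NAME, ...)        \
  switch (TYPE)                                                       \
  {                                                                   \
    case at::ScalarType::Float:                                       \
    {                                                                 \
      using scalar_t_##LEVEL = float;                                 \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::Half:                                        \
    {                                                                 \
      using scalar_t_##LEVEL = at::Half;                              \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    default:                                                          \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }

Nesting three of these doubles the instantiation count at each level, so LAMBStage1Functor is now compiled for 2 × 2 × 2 = 8 float/half combinations instead of 4.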
csrc/multi_tensor_lamb_stage_2.cu
@@ -15,7 +15,7 @@

 // Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
 // It computes new parameter value.
-template<typename T>
+template<typename T, typename UPD_T>
 struct LAMBStage2Functor
 {
   __device__ __forceinline__ void operator()(
@@ -42,7 +42,7 @@ struct LAMBStage2Functor
     T* p = (T*)tl.addresses[0][tensor_loc];
     p += chunk_idx*chunk_size;

-    T* update = (T*)tl.addresses[1][tensor_loc];
+    UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc];
     update += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
@@ -52,7 +52,7 @@ struct LAMBStage2Functor
         i_start += blockDim.x*ILP)
     {
       T r_p[ILP];
-      T r_update[ILP];
+      UPD_T r_update[ILP];
 #pragma unroll
       for(int ii = 0; ii < ILP; ii++)
       {
@@ -66,7 +66,7 @@ struct LAMBStage2Functor
 #pragma unroll
       for(int ii = 0; ii < ILP; ii++)
       {
-        r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
+        r_p[ii] = r_p[ii] - (ratio * (T)r_update[ii]);
       }
 #pragma unroll
       for(int ii = 0; ii < ILP; ii++)
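Because r_update is now held as UPD_T, the parameter step promotes it back to T before the multiply; with T = float and UPD_T = at::Half the arithmetic therefore still runs in fp32, and only the stored representation of the update term is narrowed. A one-line sketch of the promoted form, assuming those two types:

#include <cuda_fp16.h>

// Sketch: promote the fp16 update to float first, so ratio * update and the
// subtraction are computed in fp32, matching the pre-UPD_T behavior.
__device__ float apply_update_sketch(float r_p, float ratio, __half r_update)
{
  return r_p - (ratio * (float)r_update);
}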
@@ -92,15 +92,16 @@ void multi_tensor_lamb_stage2_cuda(
   using namespace at;

   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
       multi_tensor_apply<2>(
         BLOCK_SIZE,
         chunk_size,
         noop_flag,
         tensor_lists,
-        LAMBStage2Functor<scalar_t_0>(),
+        LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
         per_tensor_param_norm.data<float>(),
         per_tensor_update_norm.data<float>(),
-        learning_rate); )
+        learning_rate); ))

   AT_CUDA_CHECK(cudaGetLastError());
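Taken together, and following the comments in the two files, stage 1 produces a regular Adam-style 'update' value (now storable in fp16) and stage 2 reads it back, scales it by the per-tensor trust ratio built from param_norm and update_norm, and writes the new parameters. A scalar sketch of that split is below; it ignores chunking, gradient clipping, and weight decay, so it only approximates what the kernels compute:

#include <cmath>

// Stage 1 sketch: Adam-style update term for a single element.
float lamb_stage1_sketch(float grad, float& m, float& v,
                         float beta1, float beta2,
                         float beta1_correction, float beta2_correction,
                         float epsilon)
{
  m = beta1 * m + (1.0f - beta1) * grad;
  v = beta2 * v + (1.0f - beta2) * grad * grad;
  float m_hat = m / beta1_correction;
  float v_hat = v / beta2_correction;
  return m_hat / (std::sqrt(v_hat) + epsilon);   // the 'update' term; may be stored as fp16
}

// Stage 2 sketch: scale the update by the per-tensor trust ratio and step.
float lamb_stage2_sketch(float p, float update,
                         float param_norm, float update_norm,
                         float learning_rate)
{
  float ratio = (param_norm > 0.0f && update_norm > 0.0f)
                  ? learning_rate * (param_norm / update_norm)
                  : learning_rate;
  return p - ratio * update;
}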