OpenDAS / ColossalAI · commit 10591ecd
Authored Apr 02, 2022 by Sze-qq; committed by binmakeswell on Apr 06, 2022
[NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.cpp code style (#636)
parent 6fcb3818

Showing 1 changed file with 388 additions and 391 deletions (+388, -391)
colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
@@ -20,29 +20,23 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE
 */
 #include "cpu_adam.h"
-#include <iostream>
 #include <math.h>
-#include <memory>
 #include <omp.h>
+#include <string.h>
 #include <torch/extension.h>
+#include <iostream>
+#include <memory>
 #include <type_traits>
 #include <unordered_map>
-#include <string.h>

 static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

 // C++ interface

-void Adam_Optimizer::Step_1(float* _params,
-                            float* grads,
-                            float* _exp_avg,
-                            float* _exp_avg_sq,
-                            size_t _param_size,
-                            bool param_half_precision,
-                            bool grad_half_precision,
-                            float loss_scale)
-{
+void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
+                            float *_exp_avg_sq, size_t _param_size,
+                            bool param_half_precision,
+                            bool grad_half_precision, float loss_scale) {
   size_t rounded_size = 0;
   float betta1_minus1 = 1 - _betta1;
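The static map above is a type-erased registry: each optimizer instance is stored as a std::shared_ptr<void> keyed by an integer id. A minimal sketch of the reverse lookup, assuming the usual std::static_pointer_cast pattern (the helper name get_optimizer is invented here for illustration and is not part of this file):

// Hypothetical helper, not part of the diff: recover the concrete optimizer
// from the type-erased registry before calling one of its Step_* methods.
#include <memory>
#include <unordered_map>

struct Adam_Optimizer;  // full definition lives in cpu_adam.h

static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

std::shared_ptr<Adam_Optimizer> get_optimizer(int optimizer_id) {
  // .at() throws std::out_of_range if create_adam_optimizer was never
  // called for this id.
  return std::static_pointer_cast<Adam_Optimizer>(s_optimizers.at(optimizer_id));
}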
@@ -50,14 +44,14 @@ void Adam_Optimizer::Step_1(float* _params,
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;
  __half *params_cast_h = NULL;
  __half *grads_cast_h = NULL;
  if (param_half_precision) {
    params_cast_h = reinterpret_cast<__half *>(_params);
  }
  if (grad_half_precision) {
    grads_cast_h = reinterpret_cast<__half *>(grads);
  }
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
@@ -82,12 +76,14 @@ void Adam_Optimizer::Step_1(float* _params,
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;
#pragma omp parallel for
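ROUND_DOWN trims _param_size to a whole number of SIMD vectors so the vector loop never touches the tail, and TILE caps how many elements each OpenMP-parallel chunk covers. A sketch of the rounding arithmetic, assuming the macros follow the usual power-of-two definitions (both actually live in cpu_adam.h, so the values below are illustrative only):

// Illustrative stand-ins for macros defined in cpu_adam.h.
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))  // assumes power-of-two step
#define SIMD_WIDTH 8  // e.g. 8 floats per __m256 register under AVX2

// ROUND_DOWN(1000, SIMD_WIDTH) == 992: elements 0..991 take the SIMD path,
// elements 992..999 fall through to the scalar tail loop later in Step_1.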
@@ -120,21 +116,24 @@ void Adam_Optimizer::Step_1(float* _params,
        grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
      }
      momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
      momentum_4.data =
          SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
      variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
      grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
      variance_4.data =
          SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
      grad_4.data = SIMD_SQRT(variance_4.data);
      grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
      grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
      if (_weight_decay > 0 && _adamw_mode) {
        param_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data);
      }
      param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
      if (param_half_precision) {
        SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data);
      } else {
        SIMD_STORE(_params + i, param_4.data);
      }
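The chain of SIMD_MUL/SIMD_FMA/SIMD_SQRT/SIMD_DIV calls is the vectorized form of the scalar Adam update that appears verbatim in the tail loop below. A hedged sketch of what those macros plausibly expand to on the AVX2 branch (the real definitions are in cpu_adam.h; the half store additionally assumes F16C for the float-to-half conversion):

#include <immintrin.h>

// Plausible AVX2 expansions of the SIMD_* macros used above (illustrative,
// not the file's actual definitions).
static inline __m256 simd_fma(__m256 a, __m256 b, __m256 c) {
  return _mm256_fmadd_ps(a, b, c);  // per-lane fused a * b + c
}
static inline __m256 simd_mul(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); }
static inline __m256 simd_sqrt(__m256 a) { return _mm256_sqrt_ps(a); }
static inline __m256 simd_div(__m256 a, __m256 b) { return _mm256_div_ps(a, b); }

static inline void simd_store_half(void *dst, __m256 v) {
  // F16C: convert 8 floats to 8 half floats, then store the 128-bit result.
  _mm_storeu_si128(static_cast<__m128i *>(dst),
                   _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT));
}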
@@ -146,17 +145,23 @@ void Adam_Optimizer::Step_1(float* _params,
  if (_param_size > rounded_size) {
    for (size_t t = rounded_size; t < _param_size; t += TILE) {
      size_t copy_size = TILE;
      if ((t + TILE) > _param_size) copy_size = _param_size - t;
      size_t offset = copy_size + t;
#pragma omp parallel for
      for (size_t k = t; k < offset; k++) {
        float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k];
        if (loss_scale > 0) {
          grad /= loss_scale;
        }
        float param =
            param_half_precision ? (float)params_cast_h[k] : _params[k];
        float momentum = _exp_avg[k];
        float variance = _exp_avg_sq[k];
        if (_weight_decay > 0 && !_adamw_mode) {
          grad = param * _weight_decay + grad;
        }
        momentum = momentum * _betta1;
        momentum = grad * betta1_minus1 + momentum;
@@ -167,7 +172,9 @@ void Adam_Optimizer::Step_1(float* _params,
        grad = sqrt(variance);
        grad = grad * _bias_correction2 + _eps;
        grad = momentum / grad;
        if (_weight_decay > 0 && _adamw_mode) {
          param += w_decay * param;
        }
        param = grad * step_size + param;
        if (param_half_precision)
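Written out, this scalar tail implements the standard bias-corrected Adam update. Reading the folded constants (hedged, since they are computed elsewhere in the class): step_size carries -alpha together with the first-moment correction, and _bias_correction2 is presumably the reciprocal square root of the second-moment correction, so the code matches

m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
\theta_t = \theta_{t-1} - \alpha \cdot \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon}

In AdamW mode the weight decay is decoupled: the parameter is first shrunk via param += w_decay * param (with w_decay = -alpha * _weight_decay), rather than folding lambda * theta into the gradient as the non-AdamW branch does.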
@@ -181,24 +188,19 @@ void Adam_Optimizer::Step_1(float* _params,
    }
  }
}

void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision,
                            bool grad_half_precision, float loss_scale) {
  size_t rounded_size = 0;
  __half *params_cast_h = NULL;
  __half *grads_cast_h = NULL;
  if (param_half_precision) {
    params_cast_h = reinterpret_cast<__half *>(_params);
  }
  if (grad_half_precision) {
    grads_cast_h = reinterpret_cast<__half *>(grads);
  }
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
@@ -227,12 +229,14 @@ void Adam_Optimizer::Step_4(float* _params,
  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;
#pragma omp parallel for
@@ -249,7 +253,7 @@ void Adam_Optimizer::Step_4(float* _params,
          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
        }
        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
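The division above undoes mixed-precision loss scaling: fp16 gradients arrive pre-multiplied by loss_scale to avoid underflow, and the optimizer divides the factor back out before the moment updates. The scalar tail in Step_1 applies the same rule element-wise; a one-line restatement of the convention visible in this diff (a non-positive loss_scale means scaling is disabled):

// Sketch of the unscaling rule used in both the SIMD and scalar paths.
static inline float unscale(float grad, float loss_scale) {
  return loss_scale > 0 ? grad / loss_scale : grad;
}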
@@ -265,23 +269,29 @@ void Adam_Optimizer::Step_4(float* _params,
        }
        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
        }
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
        if (param_half_precision) {
          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
                          param_4[j].data);
        } else {
          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
        }
@@ -292,31 +302,25 @@ void Adam_Optimizer::Step_4(float* _params,
  }
#endif
  if (_param_size > rounded_size)
    Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
           grad_half_precision, loss_scale);
}

int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
                          float betta1 = 0.9, float betta2 = 0.999,
                          float eps = 1e-8, float weight_decay = 0,
                          bool adamw_mode = true, bool should_log = false) {
  auto opt = std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps,
                                              weight_decay, adamw_mode);
  s_optimizers[optimizer_id] = opt;
  if (should_log) {
    std::string avx_type = "";
#if defined(__AVX512__)
@@ -329,36 +333,26 @@ int create_adam_optimizer(int optimizer_id,
#endif
#endif
    printf("Adam Optimizer #%d is created with %s arithmetic capability.\n",
           optimizer_id, avx_type.c_str());
    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
           alpha, betta1, betta2, weight_decay, (int)adamw_mode);
  }
  return 0;
}

void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision,
                            bool grad_half_precision, float loss_scale) {
  size_t rounded_size = 0;
  __half *params_cast_h = NULL;
  __half *grads_cast_h = NULL;
  if (param_half_precision) {
    params_cast_h = reinterpret_cast<__half *>(_params);
  }
  if (grad_half_precision) {
    grads_cast_h = reinterpret_cast<__half *>(grads);
  }
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
@@ -386,12 +380,14 @@ void Adam_Optimizer::Step_8(float* _params,
  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;
#pragma omp parallel for
@@ -424,23 +420,29 @@ void Adam_Optimizer::Step_8(float* _params,
        }
        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
        }
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
        if (param_half_precision) {
          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
                          param_4[j].data);
        } else {
          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
        }
@@ -452,14 +454,13 @@ void Adam_Optimizer::Step_8(float* _params,
  }
#endif
  if (_param_size > rounded_size)
    Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
           grad_half_precision, loss_scale);
}

int adam_step(int optimizer_id,
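Step_8 rounds down to multiples of SIMD_WIDTH * 8 and hands any remainder to Step_4, which in turn hands its remainder to Step_1 (seen earlier), so every parameter element is updated exactly once. A worked example of the split, assuming AVX2's 8-float vectors (SIMD_WIDTH == 8 is an assumption here, not stated in this diff):

#include <cstddef>
#include <cstdio>

// How a 1000-element tensor cascades through Step_8 -> Step_4 -> Step_1.
// Prints: step8=960 step4=32 step1_simd=8 scalar=0
int main() {
  const size_t w = 8, n = 1000;          // w: floats per SIMD register
  size_t by8 = n / (w * 8) * (w * 8);    // 8-vector groups handled by Step_8
  size_t by4 = (n - by8) / (w * 4) * (w * 4);  // 4-vector groups in Step_4
  size_t by1 = (n - by8 - by4) / w * w;        // single vectors in Step_1
  std::printf("step8=%zu step4=%zu step1_simd=%zu scalar=%zu\n", by8, by4,
              by1, n - by8 - by4 - by1);
  return 0;
}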
@@ -501,16 +502,12 @@ int adam_step(int optimizer_id,
  return 0;
}

int destroy_adam_optimizer(int optimizer_id) {
  s_optimizers.erase(optimizer_id);
  return 0;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("adam_update", &adam_step, "CPU Adam update (C++)");
  m.def("create_adam", &create_adam_optimizer, "CPU Adam (C++)");
  m.def("destroy_adam", &destroy_adam_optimizer, "CPU Adam destroy (C++)");
...
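The module exports three entry points, giving callers a create/step/destroy lifecycle over the optimizer registry. A hedged sketch of that lifecycle at the C++ level, using only the signatures visible in this diff (adam_step's remaining parameters are truncated above and stay elided here):

// Hypothetical driver; argument values mirror create_adam_optimizer's
// defaults as shown in the hunk above.
int id = 0;
create_adam_optimizer(id, /*alpha=*/1e-3, /*betta1=*/0.9, /*betta2=*/0.999,
                      /*eps=*/1e-8, /*weight_decay=*/0,
                      /*adamw_mode=*/true, /*should_log=*/true);
// adam_step(id, ...);  // once per training step; full signature truncated above
destroy_adam_optimizer(id);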