Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
10591ecd
Commit
10591ecd
authored
Apr 02, 2022
by
Sze-qq
Committed by
binmakeswell
Apr 06, 2022
Browse files
[NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.cpp code style (#636)
parent
6fcb3818
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
388 additions
and
391 deletions
+388
-391
colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
+388
-391
No files found.
colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
View file @
10591ecd
...
...
@@ -20,446 +20,447 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"
#include <iostream>
#include <math.h>
#include <memory>
#include <omp.h>
#include <string.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <string.h>
static
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
void
>>
s_optimizers
;
// C++ interface
void
Adam_Optimizer
::
Step_1
(
float
*
_params
,
float
*
grads
,
float
*
_exp_avg
,
float
*
_exp_avg_sq
,
size_t
_param_size
,
bool
param_half_precision
,
bool
grad_half_precision
,
float
loss_scale
)
{
size_t
rounded_size
=
0
;
void
Adam_Optimizer
::
Step_1
(
float
*
_params
,
float
*
grads
,
float
*
_exp_avg
,
float
*
_exp_avg_sq
,
size_t
_param_size
,
bool
param_half_precision
,
bool
grad_half_precision
,
float
loss_scale
)
{
size_t
rounded_size
=
0
;
float
betta1_minus1
=
1
-
_betta1
;
float
betta2_minus1
=
1
-
_betta2
;
float
step_size
=
-
1
*
_alpha
/
_bias_correction1
;
float
w_decay
=
-
1
*
_alpha
*
_weight_decay
;
float
betta1_minus1
=
1
-
_betta1
;
float
betta2_minus1
=
1
-
_betta2
;
float
step_size
=
-
1
*
_alpha
/
_bias_correction1
;
float
w_decay
=
-
1
*
_alpha
*
_weight_decay
;
__half
*
params_cast_h
=
NULL
;
__half
*
grads_cast_h
=
NULL
;
__half
*
params_cast_h
=
NULL
;
__half
*
grads_cast_h
=
NULL
;
if
(
param_half_precision
)
{
params_cast_h
=
reinterpret_cast
<
__half
*>
(
_params
);
}
if
(
grad_half_precision
)
{
grads_cast_h
=
reinterpret_cast
<
__half
*>
(
grads
);
}
if
(
param_half_precision
)
{
params_cast_h
=
reinterpret_cast
<
__half
*>
(
_params
);
}
if
(
grad_half_precision
)
{
grads_cast_h
=
reinterpret_cast
<
__half
*>
(
grads
);
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data
betta1_4
;
betta1_4
.
data
=
SIMD_SET
(
_betta1
);
AVX_Data
betta2_4
;
betta2_4
.
data
=
SIMD_SET
(
_betta2
);
AVX_Data
betta1_4
;
betta1_4
.
data
=
SIMD_SET
(
_betta1
);
AVX_Data
betta2_4
;
betta2_4
.
data
=
SIMD_SET
(
_betta2
);
AVX_Data
betta1_minus1_4
;
betta1_minus1_4
.
data
=
SIMD_SET
(
betta1_minus1
);
AVX_Data
betta2_minus1_4
;
betta2_minus1_4
.
data
=
SIMD_SET
(
betta2_minus1
);
AVX_Data
betta1_minus1_4
;
betta1_minus1_4
.
data
=
SIMD_SET
(
betta1_minus1
);
AVX_Data
betta2_minus1_4
;
betta2_minus1_4
.
data
=
SIMD_SET
(
betta2_minus1
);
AVX_Data
bias2_sqrt
;
bias2_sqrt
.
data
=
SIMD_SET
(
_bias_correction2
);
AVX_Data
bias2_sqrt
;
bias2_sqrt
.
data
=
SIMD_SET
(
_bias_correction2
);
AVX_Data
eps_4
;
eps_4
.
data
=
SIMD_SET
(
_eps
);
AVX_Data
eps_4
;
eps_4
.
data
=
SIMD_SET
(
_eps
);
AVX_Data
step_size_4
;
step_size_4
.
data
=
SIMD_SET
(
step_size
);
AVX_Data
step_size_4
;
step_size_4
.
data
=
SIMD_SET
(
step_size
);
AVX_Data
weight_decay_4
;
if
(
_weight_decay
>
0
)
weight_decay_4
.
data
=
(
_adamw_mode
?
SIMD_SET
(
w_decay
)
:
SIMD_SET
(
_weight_decay
));
rounded_size
=
ROUND_DOWN
(
_param_size
,
SIMD_WIDTH
);
AVX_Data
weight_decay_4
;
if
(
_weight_decay
>
0
)
weight_decay_4
.
data
=
(
_adamw_mode
?
SIMD_SET
(
w_decay
)
:
SIMD_SET
(
_weight_decay
));
rounded_size
=
ROUND_DOWN
(
_param_size
,
SIMD_WIDTH
);
for
(
size_t
t
=
0
;
t
<
rounded_size
;
t
+=
TILE
)
{
size_t
copy_size
=
TILE
;
if
((
t
+
TILE
)
>
rounded_size
)
copy_size
=
rounded_size
-
t
;
size_t
offset
=
copy_size
+
t
;
for
(
size_t
t
=
0
;
t
<
rounded_size
;
t
+=
TILE
)
{
size_t
copy_size
=
TILE
;
if
((
t
+
TILE
)
>
rounded_size
)
copy_size
=
rounded_size
-
t
;
size_t
offset
=
copy_size
+
t
;
#pragma omp parallel for
for
(
size_t
i
=
t
;
i
<
offset
;
i
+=
SIMD_WIDTH
)
{
AVX_Data
grad_4
;
if
(
grad_half_precision
)
{
grad_4
.
data
=
SIMD_LOAD_HALF
(
grads_cast_h
+
i
);
}
else
{
grad_4
.
data
=
SIMD_LOAD
(
grads
+
i
);
}
if
(
loss_scale
>
0
)
{
AVX_Data
loss_scale_vec
;
loss_scale_vec
.
data
=
SIMD_SET
(
loss_scale
);
grad_4
.
data
=
SIMD_DIV
(
grad_4
.
data
,
loss_scale_vec
.
data
);
}
AVX_Data
momentum_4
;
momentum_4
.
data
=
SIMD_LOAD
(
_exp_avg
+
i
);
AVX_Data
variance_4
;
variance_4
.
data
=
SIMD_LOAD
(
_exp_avg_sq
+
i
);
AVX_Data
param_4
;
if
(
param_half_precision
)
{
param_4
.
data
=
SIMD_LOAD_HALF
(
params_cast_h
+
i
);
}
else
{
param_4
.
data
=
SIMD_LOAD
(
_params
+
i
);
}
if
(
_weight_decay
>
0
&&
!
_adamw_mode
)
{
grad_4
.
data
=
SIMD_FMA
(
param_4
.
data
,
weight_decay_4
.
data
,
grad_4
.
data
);
}
momentum_4
.
data
=
SIMD_MUL
(
momentum_4
.
data
,
betta1_4
.
data
);
momentum_4
.
data
=
SIMD_FMA
(
grad_4
.
data
,
betta1_minus1_4
.
data
,
momentum_4
.
data
);
variance_4
.
data
=
SIMD_MUL
(
variance_4
.
data
,
betta2_4
.
data
);
grad_4
.
data
=
SIMD_MUL
(
grad_4
.
data
,
grad_4
.
data
);
variance_4
.
data
=
SIMD_FMA
(
grad_4
.
data
,
betta2_minus1_4
.
data
,
variance_4
.
data
);
grad_4
.
data
=
SIMD_SQRT
(
variance_4
.
data
);
grad_4
.
data
=
SIMD_FMA
(
grad_4
.
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
.
data
=
SIMD_DIV
(
momentum_4
.
data
,
grad_4
.
data
);
if
(
_weight_decay
>
0
&&
_adamw_mode
)
{
param_4
.
data
=
SIMD_FMA
(
param_4
.
data
,
weight_decay_4
.
data
,
param_4
.
data
);
}
param_4
.
data
=
SIMD_FMA
(
grad_4
.
data
,
step_size_4
.
data
,
param_4
.
data
);
if
(
param_half_precision
)
{
SIMD_STORE_HALF
((
float
*
)(
params_cast_h
+
i
),
param_4
.
data
);
}
else
{
SIMD_STORE
(
_params
+
i
,
param_4
.
data
);
}
SIMD_STORE
(
_exp_avg
+
i
,
momentum_4
.
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
,
variance_4
.
data
);
}
for
(
size_t
i
=
t
;
i
<
offset
;
i
+=
SIMD_WIDTH
)
{
AVX_Data
grad_4
;
if
(
grad_half_precision
)
{
grad_4
.
data
=
SIMD_LOAD_HALF
(
grads_cast_h
+
i
);
}
else
{
grad_4
.
data
=
SIMD_LOAD
(
grads
+
i
);
}
if
(
loss_scale
>
0
)
{
AVX_Data
loss_scale_vec
;
loss_scale_vec
.
data
=
SIMD_SET
(
loss_scale
);
grad_4
.
data
=
SIMD_DIV
(
grad_4
.
data
,
loss_scale_vec
.
data
);
}
AVX_Data
momentum_4
;
momentum_4
.
data
=
SIMD_LOAD
(
_exp_avg
+
i
);
AVX_Data
variance_4
;
variance_4
.
data
=
SIMD_LOAD
(
_exp_avg_sq
+
i
);
AVX_Data
param_4
;
if
(
param_half_precision
)
{
param_4
.
data
=
SIMD_LOAD_HALF
(
params_cast_h
+
i
);
}
else
{
param_4
.
data
=
SIMD_LOAD
(
_params
+
i
);
}
if
(
_weight_decay
>
0
&&
!
_adamw_mode
)
{
grad_4
.
data
=
SIMD_FMA
(
param_4
.
data
,
weight_decay_4
.
data
,
grad_4
.
data
);
}
momentum_4
.
data
=
SIMD_MUL
(
momentum_4
.
data
,
betta1_4
.
data
);
momentum_4
.
data
=
SIMD_FMA
(
grad_4
.
data
,
betta1_minus1_4
.
data
,
momentum_4
.
data
);
variance_4
.
data
=
SIMD_MUL
(
variance_4
.
data
,
betta2_4
.
data
);
grad_4
.
data
=
SIMD_MUL
(
grad_4
.
data
,
grad_4
.
data
);
variance_4
.
data
=
SIMD_FMA
(
grad_4
.
data
,
betta2_minus1_4
.
data
,
variance_4
.
data
);
grad_4
.
data
=
SIMD_SQRT
(
variance_4
.
data
);
grad_4
.
data
=
SIMD_FMA
(
grad_4
.
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
.
data
=
SIMD_DIV
(
momentum_4
.
data
,
grad_4
.
data
);
if
(
_weight_decay
>
0
&&
_adamw_mode
)
{
param_4
.
data
=
SIMD_FMA
(
param_4
.
data
,
weight_decay_4
.
data
,
param_4
.
data
);
}
param_4
.
data
=
SIMD_FMA
(
grad_4
.
data
,
step_size_4
.
data
,
param_4
.
data
);
if
(
param_half_precision
)
{
SIMD_STORE_HALF
((
float
*
)(
params_cast_h
+
i
),
param_4
.
data
);
}
else
{
SIMD_STORE
(
_params
+
i
,
param_4
.
data
);
}
SIMD_STORE
(
_exp_avg
+
i
,
momentum_4
.
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
,
variance_4
.
data
);
}
}
#endif
if
(
_param_size
>
rounded_size
)
{
for
(
size_t
t
=
rounded_size
;
t
<
_param_size
;
t
+=
TILE
)
{
size_t
copy_size
=
TILE
;
if
((
t
+
TILE
)
>
_param_size
)
copy_size
=
_param_size
-
t
;
size_t
offset
=
copy_size
+
t
;
if
(
_param_size
>
rounded_size
)
{
for
(
size_t
t
=
rounded_size
;
t
<
_param_size
;
t
+=
TILE
)
{
size_t
copy_size
=
TILE
;
if
((
t
+
TILE
)
>
_param_size
)
copy_size
=
_param_size
-
t
;
size_t
offset
=
copy_size
+
t
;
#pragma omp parallel for
for
(
size_t
k
=
t
;
k
<
offset
;
k
++
)
{
float
grad
=
grad_half_precision
?
(
float
)
grads_cast_h
[
k
]
:
grads
[
k
];
if
(
loss_scale
>
0
)
{
grad
/=
loss_scale
;
}
float
param
=
param_half_precision
?
(
float
)
params_cast_h
[
k
]
:
_params
[
k
];
float
momentum
=
_exp_avg
[
k
];
float
variance
=
_exp_avg_sq
[
k
];
if
(
_weight_decay
>
0
&&
!
_adamw_mode
)
{
grad
=
param
*
_weight_decay
+
grad
;
}
momentum
=
momentum
*
_betta1
;
momentum
=
grad
*
betta1_minus1
+
momentum
;
variance
=
variance
*
_betta2
;
grad
=
grad
*
grad
;
variance
=
grad
*
betta2_minus1
+
variance
;
grad
=
sqrt
(
variance
);
grad
=
grad
*
_bias_correction2
+
_eps
;
grad
=
momentum
/
grad
;
if
(
_weight_decay
>
0
&&
_adamw_mode
)
{
param
+=
w_decay
*
param
;
}
param
=
grad
*
step_size
+
param
;
if
(
param_half_precision
)
params_cast_h
[
k
]
=
(
__half
)
param
;
else
_params
[
k
]
=
param
;
_exp_avg
[
k
]
=
momentum
;
_exp_avg_sq
[
k
]
=
variance
;
}
for
(
size_t
k
=
t
;
k
<
offset
;
k
++
)
{
float
grad
=
grad_half_precision
?
(
float
)
grads_cast_h
[
k
]
:
grads
[
k
];
if
(
loss_scale
>
0
)
{
grad
/=
loss_scale
;
}
float
param
=
param_half_precision
?
(
float
)
params_cast_h
[
k
]
:
_params
[
k
];
float
momentum
=
_exp_avg
[
k
];
float
variance
=
_exp_avg_sq
[
k
];
if
(
_weight_decay
>
0
&&
!
_adamw_mode
)
{
grad
=
param
*
_weight_decay
+
grad
;
}
momentum
=
momentum
*
_betta1
;
momentum
=
grad
*
betta1_minus1
+
momentum
;
variance
=
variance
*
_betta2
;
grad
=
grad
*
grad
;
variance
=
grad
*
betta2_minus1
+
variance
;
grad
=
sqrt
(
variance
);
grad
=
grad
*
_bias_correction2
+
_eps
;
grad
=
momentum
/
grad
;
if
(
_weight_decay
>
0
&&
_adamw_mode
)
{
param
+=
w_decay
*
param
;
}
param
=
grad
*
step_size
+
param
;
if
(
param_half_precision
)
params_cast_h
[
k
]
=
(
__half
)
param
;
else
_params
[
k
]
=
param
;
_exp_avg
[
k
]
=
momentum
;
_exp_avg_sq
[
k
]
=
variance
;
}
}
}
}
void
Adam_Optimizer
::
Step_4
(
float
*
_params
,
float
*
grads
,
float
*
_exp_avg
,
float
*
_exp_avg_sq
,
size_t
_param_size
,
bool
param_half_precision
,
bool
grad_half_precision
,
float
loss_scale
)
{
size_t
rounded_size
=
0
;
__half
*
params_cast_h
=
NULL
;
__half
*
grads_cast_h
=
NULL
;
if
(
param_half_precision
)
{
params_cast_h
=
reinterpret_cast
<
__half
*>
(
_params
);
}
if
(
grad_half_precision
)
{
grads_cast_h
=
reinterpret_cast
<
__half
*>
(
grads
);
}
void
Adam_Optimizer
::
Step_4
(
float
*
_params
,
float
*
grads
,
float
*
_exp_avg
,
float
*
_exp_avg_sq
,
size_t
_param_size
,
bool
param_half_precision
,
bool
grad_half_precision
,
float
loss_scale
)
{
size_t
rounded_size
=
0
;
__half
*
params_cast_h
=
NULL
;
__half
*
grads_cast_h
=
NULL
;
if
(
param_half_precision
)
{
params_cast_h
=
reinterpret_cast
<
__half
*>
(
_params
);
}
if
(
grad_half_precision
)
{
grads_cast_h
=
reinterpret_cast
<
__half
*>
(
grads
);
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data
betta1_4
;
betta1_4
.
data
=
SIMD_SET
(
_betta1
);
AVX_Data
betta2_4
;
betta2_4
.
data
=
SIMD_SET
(
_betta2
);
float
betta1_minus1
=
1
-
_betta1
;
AVX_Data
betta1_minus1_4
;
betta1_minus1_4
.
data
=
SIMD_SET
(
betta1_minus1
);
float
betta2_minus1
=
1
-
_betta2
;
AVX_Data
betta2_minus1_4
;
betta2_minus1_4
.
data
=
SIMD_SET
(
betta2_minus1
);
AVX_Data
bias2_sqrt
;
bias2_sqrt
.
data
=
SIMD_SET
(
_bias_correction2
);
AVX_Data
eps_4
;
eps_4
.
data
=
SIMD_SET
(
_eps
);
float
step_size
=
-
1
*
_alpha
/
_bias_correction1
;
AVX_Data
step_size_4
;
step_size_4
.
data
=
SIMD_SET
(
step_size
);
float
w_decay
=
-
1
*
_alpha
*
_weight_decay
;
AVX_Data
weight_decay_4
;
if
(
_weight_decay
>
0
)
weight_decay_4
.
data
=
(
_adamw_mode
?
SIMD_SET
(
w_decay
)
:
SIMD_SET
(
_weight_decay
));
rounded_size
=
ROUND_DOWN
(
_param_size
,
SIMD_WIDTH
*
4
);
for
(
size_t
t
=
0
;
t
<
rounded_size
;
t
+=
TILE
)
{
size_t
copy_size
=
TILE
;
if
((
t
+
TILE
)
>
rounded_size
)
copy_size
=
rounded_size
-
t
;
size_t
offset
=
copy_size
+
t
;
AVX_Data
betta1_4
;
betta1_4
.
data
=
SIMD_SET
(
_betta1
);
AVX_Data
betta2_4
;
betta2_4
.
data
=
SIMD_SET
(
_betta2
);
float
betta1_minus1
=
1
-
_betta1
;
AVX_Data
betta1_minus1_4
;
betta1_minus1_4
.
data
=
SIMD_SET
(
betta1_minus1
);
float
betta2_minus1
=
1
-
_betta2
;
AVX_Data
betta2_minus1_4
;
betta2_minus1_4
.
data
=
SIMD_SET
(
betta2_minus1
);
AVX_Data
bias2_sqrt
;
bias2_sqrt
.
data
=
SIMD_SET
(
_bias_correction2
);
AVX_Data
eps_4
;
eps_4
.
data
=
SIMD_SET
(
_eps
);
float
step_size
=
-
1
*
_alpha
/
_bias_correction1
;
AVX_Data
step_size_4
;
step_size_4
.
data
=
SIMD_SET
(
step_size
);
float
w_decay
=
-
1
*
_alpha
*
_weight_decay
;
AVX_Data
weight_decay_4
;
if
(
_weight_decay
>
0
)
weight_decay_4
.
data
=
(
_adamw_mode
?
SIMD_SET
(
w_decay
)
:
SIMD_SET
(
_weight_decay
));
rounded_size
=
ROUND_DOWN
(
_param_size
,
SIMD_WIDTH
*
4
);
for
(
size_t
t
=
0
;
t
<
rounded_size
;
t
+=
TILE
)
{
size_t
copy_size
=
TILE
;
if
((
t
+
TILE
)
>
rounded_size
)
copy_size
=
rounded_size
-
t
;
size_t
offset
=
copy_size
+
t
;
#pragma omp parallel for
for
(
size_t
i
=
t
;
i
<
offset
;
i
+=
SIMD_WIDTH
*
4
)
{
AVX_Data
grad_4
[
4
];
AVX_Data
momentum_4
[
4
];
AVX_Data
variance_4
[
4
];
AVX_Data
param_4
[
4
];
for
(
size_t
i
=
t
;
i
<
offset
;
i
+=
SIMD_WIDTH
*
4
)
{
AVX_Data
grad_4
[
4
];
AVX_Data
momentum_4
[
4
];
AVX_Data
variance_4
[
4
];
AVX_Data
param_4
[
4
];
#pragma unroll 4
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
if
(
grad_half_precision
)
{
grad_4
[
j
].
data
=
SIMD_LOAD_HALF
(
grads_cast_h
+
i
+
SIMD_WIDTH
*
j
);
}
else
{
grad_4
[
j
].
data
=
SIMD_LOAD
(
grads
+
i
+
SIMD_WIDTH
*
j
);
}
if
(
loss_scale
>
0
)
{
AVX_Data
loss_scale_vec
;
loss_scale_vec
.
data
=
SIMD_SET
(
loss_scale
);
grad_4
[
j
].
data
=
SIMD_DIV
(
grad_4
[
j
].
data
,
loss_scale_vec
.
data
);
}
momentum_4
[
j
].
data
=
SIMD_LOAD
(
_exp_avg
+
i
+
SIMD_WIDTH
*
j
);
variance_4
[
j
].
data
=
SIMD_LOAD
(
_exp_avg_sq
+
i
+
SIMD_WIDTH
*
j
);
if
(
param_half_precision
)
{
param_4
[
j
].
data
=
SIMD_LOAD_HALF
(
params_cast_h
+
i
+
SIMD_WIDTH
*
j
);
}
else
{
param_4
[
j
].
data
=
SIMD_LOAD
(
_params
+
i
+
SIMD_WIDTH
*
j
);
}
if
(
_weight_decay
>
0
&&
!
_adamw_mode
)
{
grad_4
[
j
].
data
=
SIMD_FMA
(
param_4
[
j
].
data
,
weight_decay_4
.
data
,
grad_4
[
j
].
data
);
}
momentum_4
[
j
].
data
=
SIMD_MUL
(
momentum_4
[
j
].
data
,
betta1_4
.
data
);
momentum_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
betta1_minus1_4
.
data
,
momentum_4
[
j
].
data
);
variance_4
[
j
].
data
=
SIMD_MUL
(
variance_4
[
j
].
data
,
betta2_4
.
data
);
grad_4
[
j
].
data
=
SIMD_MUL
(
grad_4
[
j
].
data
,
grad_4
[
j
].
data
);
variance_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
betta2_minus1_4
.
data
,
variance_4
[
j
].
data
);
grad_4
[
j
].
data
=
SIMD_SQRT
(
variance_4
[
j
].
data
);
grad_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
j
].
data
=
SIMD_DIV
(
momentum_4
[
j
].
data
,
grad_4
[
j
].
data
);
if
(
_weight_decay
>
0
&&
_adamw_mode
)
{
param_4
[
j
].
data
=
SIMD_FMA
(
param_4
[
j
].
data
,
weight_decay_4
.
data
,
param_4
[
j
].
data
);
}
param_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
step_size_4
.
data
,
param_4
[
j
].
data
);
if
(
param_half_precision
)
{
SIMD_STORE_HALF
((
float
*
)(
params_cast_h
+
i
+
SIMD_WIDTH
*
j
),
param_4
[
j
].
data
);
}
else
{
SIMD_STORE
(
_params
+
i
+
SIMD_WIDTH
*
j
,
param_4
[
j
].
data
);
}
SIMD_STORE
(
_exp_avg
+
i
+
SIMD_WIDTH
*
j
,
momentum_4
[
j
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
+
SIMD_WIDTH
*
j
,
variance_4
[
j
].
data
);
}
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
if
(
grad_half_precision
)
{
grad_4
[
j
].
data
=
SIMD_LOAD_HALF
(
grads_cast_h
+
i
+
SIMD_WIDTH
*
j
);
}
else
{
grad_4
[
j
].
data
=
SIMD_LOAD
(
grads
+
i
+
SIMD_WIDTH
*
j
);
}
if
(
loss_scale
>
0
)
{
AVX_Data
loss_scale_vec
;
loss_scale_vec
.
data
=
SIMD_SET
(
loss_scale
);
grad_4
[
j
].
data
=
SIMD_DIV
(
grad_4
[
j
].
data
,
loss_scale_vec
.
data
);
}
momentum_4
[
j
].
data
=
SIMD_LOAD
(
_exp_avg
+
i
+
SIMD_WIDTH
*
j
);
variance_4
[
j
].
data
=
SIMD_LOAD
(
_exp_avg_sq
+
i
+
SIMD_WIDTH
*
j
);
if
(
param_half_precision
)
{
param_4
[
j
].
data
=
SIMD_LOAD_HALF
(
params_cast_h
+
i
+
SIMD_WIDTH
*
j
);
}
else
{
param_4
[
j
].
data
=
SIMD_LOAD
(
_params
+
i
+
SIMD_WIDTH
*
j
);
}
if
(
_weight_decay
>
0
&&
!
_adamw_mode
)
{
grad_4
[
j
].
data
=
SIMD_FMA
(
param_4
[
j
].
data
,
weight_decay_4
.
data
,
grad_4
[
j
].
data
);
}
momentum_4
[
j
].
data
=
SIMD_MUL
(
momentum_4
[
j
].
data
,
betta1_4
.
data
);
momentum_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
betta1_minus1_4
.
data
,
momentum_4
[
j
].
data
);
variance_4
[
j
].
data
=
SIMD_MUL
(
variance_4
[
j
].
data
,
betta2_4
.
data
);
grad_4
[
j
].
data
=
SIMD_MUL
(
grad_4
[
j
].
data
,
grad_4
[
j
].
data
);
variance_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
betta2_minus1_4
.
data
,
variance_4
[
j
].
data
);
grad_4
[
j
].
data
=
SIMD_SQRT
(
variance_4
[
j
].
data
);
grad_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
j
].
data
=
SIMD_DIV
(
momentum_4
[
j
].
data
,
grad_4
[
j
].
data
);
if
(
_weight_decay
>
0
&&
_adamw_mode
)
{
param_4
[
j
].
data
=
SIMD_FMA
(
param_4
[
j
].
data
,
weight_decay_4
.
data
,
param_4
[
j
].
data
);
}
param_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
step_size_4
.
data
,
param_4
[
j
].
data
);
if
(
param_half_precision
)
{
SIMD_STORE_HALF
((
float
*
)(
params_cast_h
+
i
+
SIMD_WIDTH
*
j
),
param_4
[
j
].
data
);
}
else
{
SIMD_STORE
(
_params
+
i
+
SIMD_WIDTH
*
j
,
param_4
[
j
].
data
);
}
SIMD_STORE
(
_exp_avg
+
i
+
SIMD_WIDTH
*
j
,
momentum_4
[
j
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
+
SIMD_WIDTH
*
j
,
variance_4
[
j
].
data
);
}
}
}
#endif
if
(
_param_size
>
rounded_size
)
Step_1
((
param_half_precision
?
(
float
*
)(
params_cast_h
+
rounded_size
)
:
_params
+
rounded_size
),
(
grad_half_precision
?
(
float
*
)(
grads_cast_h
+
rounded_size
)
:
grads
+
rounded_size
),
(
_exp_avg
+
rounded_size
),
(
_exp_avg_sq
+
rounded_size
),
(
_param_size
-
rounded_size
),
param_half_precision
,
grad_half_precision
,
loss_scale
);
if
(
_param_size
>
rounded_size
)
Step_1
((
param_half_precision
?
(
float
*
)(
params_cast_h
+
rounded_size
)
:
_params
+
rounded_size
),
(
grad_half_precision
?
(
float
*
)(
grads_cast_h
+
rounded_size
)
:
grads
+
rounded_size
),
(
_exp_avg
+
rounded_size
),
(
_exp_avg_sq
+
rounded_size
),
(
_param_size
-
rounded_size
),
param_half_precision
,
grad_half_precision
,
loss_scale
);
}
int
create_adam_optimizer
(
int
optimizer_id
,
float
alpha
=
1e-3
,
float
betta1
=
0.9
,
float
betta2
=
0.999
,
float
eps
=
1e-8
,
float
weight_decay
=
0
,
bool
adamw_mode
=
true
,
bool
should_log
=
false
)
{
auto
opt
=
std
::
make_shared
<
Adam_Optimizer
>
(
alpha
,
betta1
,
betta2
,
eps
,
weight_decay
,
adamw_mode
);
int
create_adam_optimizer
(
int
optimizer_id
,
float
alpha
=
1e-3
,
float
betta1
=
0.9
,
float
betta2
=
0.999
,
float
eps
=
1e-8
,
float
weight_decay
=
0
,
bool
adamw_mode
=
true
,
bool
should_log
=
false
)
{
auto
opt
=
std
::
make_shared
<
Adam_Optimizer
>
(
alpha
,
betta1
,
betta2
,
eps
,
weight_decay
,
adamw_mode
);
s_optimizers
[
optimizer_id
]
=
opt
;
s_optimizers
[
optimizer_id
]
=
opt
;
if
(
should_log
){
if
(
should_log
)
{
std
::
string
avx_type
=
""
;
std
::
string
avx_type
=
""
;
#if defined(__AVX512__)
avx_type
=
"AVX512"
;
avx_type
=
"AVX512"
;
#else
#if defined(__AVX256__) or defined(__AVX2__)
avx_type
=
"AVX2"
;
avx_type
=
"AVX2"
;
#else
avx_type
=
"scalar"
;
avx_type
=
"scalar"
;
#endif
#endif
printf
(
"Adam Optimizer #%d is created with %s arithmetic capability.
\n
"
,
optimizer_id
,
avx_type
.
c_str
());
printf
(
"Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d
\n
"
,
alpha
,
betta1
,
betta2
,
weight_decay
,
(
int
)
adamw_mode
);
}
printf
(
"Adam Optimizer #%d is created with %s arithmetic capability.
\n
"
,
optimizer_id
,
avx_type
.
c_str
());
printf
(
"Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d
\n
"
,
alpha
,
betta1
,
betta2
,
weight_decay
,
(
int
)
adamw_mode
);
}
return
0
;
return
0
;
}
void
Adam_Optimizer
::
Step_8
(
float
*
_params
,
float
*
grads
,
float
*
_exp_avg
,
float
*
_exp_avg_sq
,
size_t
_param_size
,
bool
param_half_precision
,
bool
grad_half_precision
,
float
loss_scale
)
{
size_t
rounded_size
=
0
;
__half
*
params_cast_h
=
NULL
;
__half
*
grads_cast_h
=
NULL
;
if
(
param_half_precision
)
{
params_cast_h
=
reinterpret_cast
<
__half
*>
(
_params
);
}
if
(
grad_half_precision
)
{
grads_cast_h
=
reinterpret_cast
<
__half
*>
(
grads
);
}
void
Adam_Optimizer
::
Step_8
(
float
*
_params
,
float
*
grads
,
float
*
_exp_avg
,
float
*
_exp_avg_sq
,
size_t
_param_size
,
bool
param_half_precision
,
bool
grad_half_precision
,
float
loss_scale
)
{
size_t
rounded_size
=
0
;
__half
*
params_cast_h
=
NULL
;
__half
*
grads_cast_h
=
NULL
;
if
(
param_half_precision
)
{
params_cast_h
=
reinterpret_cast
<
__half
*>
(
_params
);
}
if
(
grad_half_precision
)
{
grads_cast_h
=
reinterpret_cast
<
__half
*>
(
grads
);
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data
betta1_4
;
betta1_4
.
data
=
SIMD_SET
(
_betta1
);
AVX_Data
betta2_4
;
betta2_4
.
data
=
SIMD_SET
(
_betta2
);
float
betta1_minus1
=
1
-
_betta1
;
AVX_Data
betta1_minus1_4
;
betta1_minus1_4
.
data
=
SIMD_SET
(
betta1_minus1
);
float
betta2_minus1
=
1
-
_betta2
;
AVX_Data
betta2_minus1_4
;
betta2_minus1_4
.
data
=
SIMD_SET
(
betta2_minus1
);
AVX_Data
bias2_sqrt
;
bias2_sqrt
.
data
=
SIMD_SET
(
_bias_correction2
);
AVX_Data
eps_4
;
eps_4
.
data
=
SIMD_SET
(
_eps
);
float
step_size
=
-
1
*
_alpha
/
_bias_correction1
;
AVX_Data
step_size_4
;
step_size_4
.
data
=
SIMD_SET
(
step_size
);
float
w_decay
=
-
1
*
_alpha
*
_weight_decay
;
AVX_Data
weight_decay_4
;
if
(
_weight_decay
>
0
)
weight_decay_4
.
data
=
(
_adamw_mode
?
SIMD_SET
(
w_decay
)
:
SIMD_SET
(
_weight_decay
));
rounded_size
=
ROUND_DOWN
(
_param_size
,
SIMD_WIDTH
*
8
);
for
(
size_t
t
=
0
;
t
<
rounded_size
;
t
+=
TILE
)
{
size_t
copy_size
=
TILE
;
if
((
t
+
TILE
)
>
rounded_size
)
copy_size
=
rounded_size
-
t
;
size_t
offset
=
copy_size
+
t
;
AVX_Data
betta1_4
;
betta1_4
.
data
=
SIMD_SET
(
_betta1
);
AVX_Data
betta2_4
;
betta2_4
.
data
=
SIMD_SET
(
_betta2
);
float
betta1_minus1
=
1
-
_betta1
;
AVX_Data
betta1_minus1_4
;
betta1_minus1_4
.
data
=
SIMD_SET
(
betta1_minus1
);
float
betta2_minus1
=
1
-
_betta2
;
AVX_Data
betta2_minus1_4
;
betta2_minus1_4
.
data
=
SIMD_SET
(
betta2_minus1
);
AVX_Data
bias2_sqrt
;
bias2_sqrt
.
data
=
SIMD_SET
(
_bias_correction2
);
AVX_Data
eps_4
;
eps_4
.
data
=
SIMD_SET
(
_eps
);
float
step_size
=
-
1
*
_alpha
/
_bias_correction1
;
AVX_Data
step_size_4
;
step_size_4
.
data
=
SIMD_SET
(
step_size
);
float
w_decay
=
-
1
*
_alpha
*
_weight_decay
;
AVX_Data
weight_decay_4
;
if
(
_weight_decay
>
0
)
weight_decay_4
.
data
=
(
_adamw_mode
?
SIMD_SET
(
w_decay
)
:
SIMD_SET
(
_weight_decay
));
rounded_size
=
ROUND_DOWN
(
_param_size
,
SIMD_WIDTH
*
8
);
for
(
size_t
t
=
0
;
t
<
rounded_size
;
t
+=
TILE
)
{
size_t
copy_size
=
TILE
;
if
((
t
+
TILE
)
>
rounded_size
)
copy_size
=
rounded_size
-
t
;
size_t
offset
=
copy_size
+
t
;
#pragma omp parallel for
for
(
size_t
i
=
t
;
i
<
offset
;
i
+=
SIMD_WIDTH
*
8
)
{
AVX_Data
grad_4
[
8
];
AVX_Data
momentum_4
[
8
];
AVX_Data
variance_4
[
8
];
AVX_Data
param_4
[
8
];
for
(
size_t
i
=
t
;
i
<
offset
;
i
+=
SIMD_WIDTH
*
8
)
{
AVX_Data
grad_4
[
8
];
AVX_Data
momentum_4
[
8
];
AVX_Data
variance_4
[
8
];
AVX_Data
param_4
[
8
];
#pragma unroll 8
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
if
(
grad_half_precision
)
{
grad_4
[
j
].
data
=
SIMD_LOAD_HALF
(
grads_cast_h
+
i
+
SIMD_WIDTH
*
j
);
}
else
{
grad_4
[
j
].
data
=
SIMD_LOAD
(
grads
+
i
+
SIMD_WIDTH
*
j
);
}
if
(
loss_scale
>
0
)
{
AVX_Data
loss_scale_vec
;
loss_scale_vec
.
data
=
SIMD_SET
(
loss_scale
);
grad_4
[
j
].
data
=
SIMD_DIV
(
grad_4
[
j
].
data
,
loss_scale_vec
.
data
);
}
momentum_4
[
j
].
data
=
SIMD_LOAD
(
_exp_avg
+
i
+
SIMD_WIDTH
*
j
);
variance_4
[
j
].
data
=
SIMD_LOAD
(
_exp_avg_sq
+
i
+
SIMD_WIDTH
*
j
);
if
(
param_half_precision
)
{
param_4
[
j
].
data
=
SIMD_LOAD_HALF
(
params_cast_h
+
i
+
SIMD_WIDTH
*
j
);
}
else
{
param_4
[
j
].
data
=
SIMD_LOAD
(
_params
+
i
+
SIMD_WIDTH
*
j
);
}
if
(
_weight_decay
>
0
&&
!
_adamw_mode
)
{
grad_4
[
j
].
data
=
SIMD_FMA
(
param_4
[
j
].
data
,
weight_decay_4
.
data
,
grad_4
[
j
].
data
);
}
momentum_4
[
j
].
data
=
SIMD_MUL
(
momentum_4
[
j
].
data
,
betta1_4
.
data
);
momentum_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
betta1_minus1_4
.
data
,
momentum_4
[
j
].
data
);
variance_4
[
j
].
data
=
SIMD_MUL
(
variance_4
[
j
].
data
,
betta2_4
.
data
);
grad_4
[
j
].
data
=
SIMD_MUL
(
grad_4
[
j
].
data
,
grad_4
[
j
].
data
);
variance_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
betta2_minus1_4
.
data
,
variance_4
[
j
].
data
);
grad_4
[
j
].
data
=
SIMD_SQRT
(
variance_4
[
j
].
data
);
grad_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
j
].
data
=
SIMD_DIV
(
momentum_4
[
j
].
data
,
grad_4
[
j
].
data
);
if
(
_weight_decay
>
0
&&
_adamw_mode
)
{
param_4
[
j
].
data
=
SIMD_FMA
(
param_4
[
j
].
data
,
weight_decay_4
.
data
,
param_4
[
j
].
data
);
}
param_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
step_size_4
.
data
,
param_4
[
j
].
data
);
if
(
param_half_precision
)
{
SIMD_STORE_HALF
((
float
*
)(
params_cast_h
+
i
+
SIMD_WIDTH
*
j
),
param_4
[
j
].
data
);
}
else
{
SIMD_STORE
(
_params
+
i
+
SIMD_WIDTH
*
j
,
param_4
[
j
].
data
);
}
SIMD_STORE
(
_exp_avg
+
i
+
(
SIMD_WIDTH
*
j
),
momentum_4
[
j
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
+
(
SIMD_WIDTH
*
j
),
variance_4
[
j
].
data
);
}
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
if
(
grad_half_precision
)
{
grad_4
[
j
].
data
=
SIMD_LOAD_HALF
(
grads_cast_h
+
i
+
SIMD_WIDTH
*
j
);
}
else
{
grad_4
[
j
].
data
=
SIMD_LOAD
(
grads
+
i
+
SIMD_WIDTH
*
j
);
}
if
(
loss_scale
>
0
)
{
AVX_Data
loss_scale_vec
;
loss_scale_vec
.
data
=
SIMD_SET
(
loss_scale
);
grad_4
[
j
].
data
=
SIMD_DIV
(
grad_4
[
j
].
data
,
loss_scale_vec
.
data
);
}
momentum_4
[
j
].
data
=
SIMD_LOAD
(
_exp_avg
+
i
+
SIMD_WIDTH
*
j
);
variance_4
[
j
].
data
=
SIMD_LOAD
(
_exp_avg_sq
+
i
+
SIMD_WIDTH
*
j
);
if
(
param_half_precision
)
{
param_4
[
j
].
data
=
SIMD_LOAD_HALF
(
params_cast_h
+
i
+
SIMD_WIDTH
*
j
);
}
else
{
param_4
[
j
].
data
=
SIMD_LOAD
(
_params
+
i
+
SIMD_WIDTH
*
j
);
}
if
(
_weight_decay
>
0
&&
!
_adamw_mode
)
{
grad_4
[
j
].
data
=
SIMD_FMA
(
param_4
[
j
].
data
,
weight_decay_4
.
data
,
grad_4
[
j
].
data
);
}
momentum_4
[
j
].
data
=
SIMD_MUL
(
momentum_4
[
j
].
data
,
betta1_4
.
data
);
momentum_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
betta1_minus1_4
.
data
,
momentum_4
[
j
].
data
);
variance_4
[
j
].
data
=
SIMD_MUL
(
variance_4
[
j
].
data
,
betta2_4
.
data
);
grad_4
[
j
].
data
=
SIMD_MUL
(
grad_4
[
j
].
data
,
grad_4
[
j
].
data
);
variance_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
betta2_minus1_4
.
data
,
variance_4
[
j
].
data
);
grad_4
[
j
].
data
=
SIMD_SQRT
(
variance_4
[
j
].
data
);
grad_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
j
].
data
=
SIMD_DIV
(
momentum_4
[
j
].
data
,
grad_4
[
j
].
data
);
if
(
_weight_decay
>
0
&&
_adamw_mode
)
{
param_4
[
j
].
data
=
SIMD_FMA
(
param_4
[
j
].
data
,
weight_decay_4
.
data
,
param_4
[
j
].
data
);
}
param_4
[
j
].
data
=
SIMD_FMA
(
grad_4
[
j
].
data
,
step_size_4
.
data
,
param_4
[
j
].
data
);
if
(
param_half_precision
)
{
SIMD_STORE_HALF
((
float
*
)(
params_cast_h
+
i
+
SIMD_WIDTH
*
j
),
param_4
[
j
].
data
);
}
else
{
SIMD_STORE
(
_params
+
i
+
SIMD_WIDTH
*
j
,
param_4
[
j
].
data
);
}
SIMD_STORE
(
_exp_avg
+
i
+
(
SIMD_WIDTH
*
j
),
momentum_4
[
j
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
+
(
SIMD_WIDTH
*
j
),
variance_4
[
j
].
data
);
}
}
}
#endif
if
(
_param_size
>
rounded_size
)
Step_4
((
param_half_precision
?
(
float
*
)(
params_cast_h
+
rounded_size
)
:
_params
+
rounded_size
),
(
grad_half_precision
?
(
float
*
)(
grads_cast_h
+
rounded_size
)
:
grads
+
rounded_size
),
(
_exp_avg
+
rounded_size
),
(
_exp_avg_sq
+
rounded_size
),
(
_param_size
-
rounded_size
),
param_half_precision
,
grad_half_precision
,
loss_scale
);
if
(
_param_size
>
rounded_size
)
Step_4
((
param_half_precision
?
(
float
*
)(
params_cast_h
+
rounded_size
)
:
_params
+
rounded_size
),
(
grad_half_precision
?
(
float
*
)(
grads_cast_h
+
rounded_size
)
:
grads
+
rounded_size
),
(
_exp_avg
+
rounded_size
),
(
_exp_avg_sq
+
rounded_size
),
(
_param_size
-
rounded_size
),
param_half_precision
,
grad_half_precision
,
loss_scale
);
}
int
adam_step
(
int
optimizer_id
,
...
...
@@ -501,17 +502,13 @@ int adam_step(int optimizer_id,
return
0
;
}
int
destroy_adam_optimizer
(
int
optimizer_id
)
{
s_optimizers
.
erase
(
optimizer_id
);
return
0
;
int
destroy_adam_optimizer
(
int
optimizer_id
)
{
s_optimizers
.
erase
(
optimizer_id
);
return
0
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"adam_update"
,
&
adam_step
,
"CPU Adam update (C++)"
);
m
.
def
(
"create_adam"
,
&
create_adam_optimizer
,
"CPU Adam (C++)"
);
m
.
def
(
"destroy_adam"
,
&
destroy_adam_optimizer
,
"CPU Adam destroy (C++)"
);
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"adam_update"
,
&
adam_step
,
"CPU Adam update (C++)"
);
m
.
def
(
"create_adam"
,
&
create_adam_optimizer
,
"CPU Adam (C++)"
);
m
.
def
(
"destroy_adam"
,
&
destroy_adam_optimizer
,
"CPU Adam destroy (C++)"
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment