OpenDAS / ColossalAI · Commits

Commit f2d2a159
Authored Mar 31, 2022 by xuqifan897
Committed by binmakeswell, Apr 06, 2022

fix format (#608)

Parent: f2da21a8

Showing 1 changed file with 71 additions and 81 deletions:

colossalai/kernel/cuda_native/csrc/cpu_adam.h (+71, -81)
@@ -21,11 +21,11 @@ SOFTWARE
 */

 #pragma once
+#include <cublas_v2.h>
+#include <cuda.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime_api.h>
 #include <stdio.h>
-#include <cuda.h>
-#include <cublas_v2.h>
 #if (__x86_64__ || __i386__)
 #include <cpuid.h>
@@ -48,8 +48,11 @@ SOFTWARE
 #define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
 #define SIMD_SQRT(x) _mm512_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
-#define SIMD_LOAD_HALF(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x)))
-#define SIMD_STORE_HALF(x, d) _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_LOAD_HALF(x) \
+  _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm256_store_ps( \
+      x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 #elif defined(__AVX256__) or defined(__AVX2__)
 #define SIMD_WIDTH 8
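Note on the macro pair being reformatted here: SIMD_LOAD_HALF widens a group of fp16 values to fp32 lanes and SIMD_STORE_HALF narrows them back. A minimal usage sketch for the AVX2/F16C path follows (compile with -mavx2 -mf16c); the function name and buffer are hypothetical, not from this header.

// Sketch only: scale a buffer of fp16 values in place using the same
// intrinsics the AVX2 SIMD_LOAD_HALF / SIMD_STORE_HALF macros expand to.
#include <immintrin.h>
#include <cstddef>
#include <cstdint>

void scale_half_buffer(uint16_t *half_buf, float scale, size_t n) {
  // Process 8 fp16 values per iteration (SIMD_WIDTH == 8 on this path).
  for (size_t i = 0; i + 8 <= n; i += 8) {
    // SIMD_LOAD_HALF: load 128 bits of fp16 and widen to 8 fp32 lanes.
    __m256 v =
        _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(half_buf + i)));
    v = _mm256_mul_ps(v, _mm256_set1_ps(scale));
    // SIMD_STORE_HALF: narrow back to fp16 and store the 128-bit result.
    // Like the macro's _mm_store_ps, this needs a 16-byte-aligned destination.
    _mm_store_ps(
        (float *)(half_buf + i),
        _mm_castsi128_ps(_mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT)));
  }
}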
@@ -62,102 +65,89 @@ SOFTWARE
 #define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
-#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x)))
-#define SIMD_STORE_HALF(x, d) _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm_store_ps( \
+      x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 #endif

 union AVX_Data {
 #if defined(__AVX512__)
   __m512 data;
 #elif defined(__AVX256__) or defined(__AVX2__)
   __m256 data;
 #endif
   // float data_f[16];
 };
 #endif

 #define STEP(SPAN) \
-  void Step_##SPAN(float* _params, \
-                   float* grads, \
-                   float* _exp_avg, \
-                   float* _exp_avg_sq, \
-                   size_t _param_size, \
-                   bool param_half_precision = false, \
-                   bool grad_half_precision = false, \
-                   float loss_scale = -1);
+  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
+                   float *_exp_avg_sq, size_t _param_size, \
+                   bool param_half_precision = false, \
+                   bool grad_half_precision = false, float loss_scale = -1);

 class Adam_Optimizer {
  public:
-  Adam_Optimizer(float alpha = 1e-3,
-                 float betta1 = 0.9,
-                 float betta2 = 0.999,
-                 float eps = 1e-8,
-                 float weight_decay = 0,
+  Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
+                 float eps = 1e-8, float weight_decay = 0,
                  bool adamw_mode = true)
       : _alpha(alpha),
         _betta1(betta1),
         _betta2(betta2),
         _eps(eps),
         _weight_decay(weight_decay),
         _betta1_t(1.0),
         _betta2_t(1.0),
         _step(0),
-        _adamw_mode(adamw_mode){}
-  ~Adam_Optimizer(){}
+        _adamw_mode(adamw_mode) {}
+  ~Adam_Optimizer() {}

   STEP(1)
   STEP(4)
   STEP(8)

   inline void IncrementStep(size_t step, float beta1, float beta2) {
     if (beta1 != _betta1 || beta2 != _betta2) {
       _step = step;
       _betta1 = beta1;
       _betta2 = beta2;
       _betta1_t = std::pow(_betta1, step);
       _betta2_t = std::pow(_betta2, step);
     } else {
       _step++;
       if (_step != step) {
         _betta1_t = std::pow(_betta1, step);
         _betta2_t = std::pow(_betta2, step);
         _step = step;
       } else {
         _betta1_t *= _betta1;
         _betta2_t *= _betta2;
       }
     }
   }

-  inline void update_state(float lr, float epsilon, float weight_decay,
-                           bool bias_correction)
-  {
+  inline void update_state(float lr, float epsilon, float weight_decay,
+                           bool bias_correction) {
     _alpha = lr;
     _eps = epsilon;
     _weight_decay = weight_decay;

     _bias_correction1 = 1.0f;
     _bias_correction2 = 1.0f;
     if (bias_correction == 1) {
       _bias_correction1 = 1 - _betta1_t;
       _bias_correction2 = 1 / sqrt(1 - _betta2_t);
     }
   }

  private:
   float _alpha;
   float _betta1;
   float _betta2;
   float _eps;
   float _weight_decay;
   float _betta1_t;
   float _betta2_t;
   size_t _step;
   float _bias_correction1;
   float _bias_correction2;
   bool _adamw_mode;
 };
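Note on the fields this hunk reformats: the factors cached by update_state are the standard Adam bias corrections, _bias_correction1 = 1 - beta1^t and _bias_correction2 = 1 / sqrt(1 - beta2^t). The Step_1/4/8 bodies declared via STEP are vectorized and defined outside this header, so the scalar helper below is a hypothetical sketch of the per-element math they apply, not the actual implementation.

// Sketch only: one scalar Adam/AdamW parameter update using the cached
// fields of Adam_Optimizer (passed in here as plain arguments).
#include <cmath>

inline float adam_update(float param, float grad, float &exp_avg,
                         float &exp_avg_sq, float alpha, float betta1,
                         float betta2, float eps, float weight_decay,
                         float bias_correction1, float bias_correction2,
                         bool adamw_mode) {
  if (weight_decay > 0 && !adamw_mode)
    grad += weight_decay * param;  // classic L2 term folded into the gradient
  exp_avg = betta1 * exp_avg + (1 - betta1) * grad;               // 1st moment
  exp_avg_sq = betta2 * exp_avg_sq + (1 - betta2) * grad * grad;  // 2nd moment
  // bias_correction2 caches 1 / sqrt(1 - betta2^t), so multiply here.
  float denom = std::sqrt(exp_avg_sq) * bias_correction2 + eps;
  if (weight_decay > 0 && adamw_mode)
    param -= alpha * weight_decay * param;  // decoupled (AdamW) decay
  // bias_correction1 caches 1 - betta1^t.
  return param - (alpha / bias_correction1) * exp_avg / denom;
}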