Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
6c377b39
Commit
6c377b39
authored
Mar 10, 2023
by
Phil Wang
Browse files
always pass beta2 into all the 1state functions
parent
abbe65ad
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
20 deletions
+22
-20
csrc/kernels.cu
csrc/kernels.cu
+10
-8
csrc/kernels.cuh
csrc/kernels.cuh
+4
-4
csrc/ops.cu
csrc/ops.cu
+8
-8
No files found.
csrc/kernels.cu
View file @
6c377b39
...
...
@@ -751,7 +751,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__
(
BLOCK_SIZE
/
NUM_VALS
,
1
)
__global__
void
kPreconditionOptimizer32bit1State
(
T
*
g
,
T
*
p
,
float
*
state1
,
float
*
unorm
,
const
float
beta1
,
const
float
eps
,
const
float
weight_decay
,
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
float
weight_decay
,
const
int
step
,
const
float
lr
,
const
float
gnorm_scale
,
const
int
n
)
{
...
...
@@ -833,7 +833,7 @@ template<typename T, int OPTIMIZER>
__launch_bounds__
(
TH
,
1
)
__global__
void
kOptimizer32bit1State
(
T
*
g
,
T
*
p
,
float
*
state1
,
float
*
unorm
,
const
float
max_unorm
,
const
float
param_norm
,
const
float
beta1
,
const
float
eps
,
const
float
weight_decay
,
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
float
weight_decay
,
const
int
step
,
const
float
lr
,
const
float
gnorm_scale
,
const
bool
skip_zeros
,
const
int
n
)
{
...
...
@@ -1175,7 +1175,7 @@ __global__ void
__launch_bounds__
(
NUM_THREADS
,
2
)
kPreconditionOptimizerStatic8bit1State
(
T
*
p
,
T
*
__restrict__
const
g
,
unsigned
char
*
__restrict__
const
state1
,
float
*
unorm
,
const
float
beta1
,
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
int
step
,
float
*
__restrict__
const
quantiles1
,
float
*
max1
,
float
*
new_max1
,
...
...
@@ -1238,7 +1238,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
break
;
case
LION
:
// using eps as beta2
s1_vals
[
j
]
=
s1_vals
[
j
]
*
eps
+
((
1.0
f
-
eps
)
*
g_val
);
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta2
+
((
1.0
f
-
beta2
)
*
g_val
);
break
;
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
...
...
@@ -1265,7 +1265,7 @@ template<typename T, int OPTIMIZER>
__global__
void
kOptimizerStatic8bit1State
(
T
*
p
,
T
*
const
g
,
unsigned
char
*
state1
,
const
float
*
unorm
,
const
float
max_unorm
,
const
float
param_norm
,
const
float
beta1
,
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
int
step
,
const
float
lr
,
float
*
__restrict__
const
quantiles1
,
float
*
max1
,
float
*
new_max1
,
...
...
@@ -1356,7 +1356,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
case
LION
:
// using eps as beta2
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
(
lr
*
sgn
(((
float
)
s1_vals
[
j
])
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_val
))));
s1_vals
[
j
]
=
s1_vals
[
j
]
*
eps
+
((
1.0
f
-
eps
)
*
g_val
);
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta2
+
((
1.0
f
-
beta2
)
*
g_val
);
break
;
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
...
...
@@ -2745,7 +2745,7 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float *unorm, \
const float beta1, const float eps, const float weight_decay, \
const float beta1, const float
beta2, const float
eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \
MAKE_PreconditionOptimizer32bit1State
(
MOMENTUM
,
half
)
...
...
@@ -2759,7 +2759,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
const float beta1, const float
beta2, const float
eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_Optimizer32bit1State
(
MOMENTUM
,
half
)
MAKE_Optimizer32bit1State
(
MOMENTUM
,
float
)
...
...
@@ -2788,6 +2788,7 @@ template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p,
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
float *unorm, \
const float beta1, \
const float beta2, \
const float eps, const int step, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
...
...
@@ -2806,6 +2807,7 @@ MAKE_PreconditionStatic8bit1State(LION, float)
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, \
const float beta2, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
...
...
csrc/kernels.cuh
View file @
6c377b39
...
...
@@ -32,20 +32,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
template
<
typename
T
,
int
OPTIMIZER
,
int
BLOCK_SIZE
,
int
NUM_VALS
>
__global__
void
kPreconditionOptimizer32bit1State
(
T
*
g
,
T
*
p
,
float
*
state1
,
float
*
unorm
,
const
float
beta1
,
const
float
eps
,
const
float
weight_decay
,
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
float
weight_decay
,
const
int
step
,
const
float
lr
,
const
float
gnorm_scale
,
const
int
n
);
template
<
typename
T
,
int
OPTIMIZER
>
__global__
void
kOptimizer32bit1State
(
T
*
g
,
T
*
p
,
float
*
state1
,
float
*
unorm
,
const
float
max_unorm
,
const
float
param_norm
,
const
float
beta1
,
const
float
eps
,
const
float
weight_decay
,
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
float
weight_decay
,
const
int
step
,
const
float
lr
,
const
float
gnorm_scale
,
const
bool
skip_zeros
,
const
int
n
);
template
<
typename
T
,
int
OPTIMIZER
>
__global__
void
kPreconditionOptimizerStatic8bit1State
(
T
*
p
,
T
*
__restrict__
const
g
,
unsigned
char
*
__restrict__
const
state1
,
float
*
unorm
,
const
float
beta1
,
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
int
step
,
float
*
__restrict__
const
quantiles1
,
float
*
max1
,
float
*
new_max1
,
...
...
@@ -57,7 +57,7 @@ template<typename T, int OPTIMIZER>
__global__
void
kOptimizerStatic8bit1State
(
T
*
p
,
T
*
const
g
,
unsigned
char
*
state1
,
const
float
*
unorm
,
const
float
max_unorm
,
const
float
param_norm
,
const
float
beta1
,
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
int
step
,
const
float
lr
,
float
*
__restrict__
const
quantiles1
,
float
*
max1
,
float
*
new_max1
,
...
...
csrc/ops.cu
View file @
6c377b39
...
...
@@ -123,22 +123,22 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
if
(
max_unorm
>
0.0
f
)
{
CUDA_CHECK_RETURN
(
cudaMemset
(
unorm
,
0
,
1
*
sizeof
(
float
)));
kPreconditionOptimizer32bit1State
<
T
,
OPTIMIZER
,
4096
,
8
><<<
num_blocks
,
512
>>>
(
g
,
p
,
state1
,
unorm
,
beta1
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
n
);
kPreconditionOptimizer32bit1State
<
T
,
OPTIMIZER
,
4096
,
8
><<<
num_blocks
,
512
>>>
(
g
,
p
,
state1
,
unorm
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
kOptimizer32bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
g
,
p
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
skip_zeros
,
n
);
kOptimizer32bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
g
,
p
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
skip_zeros
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
break
;
case
LION
:
// in lion, the momentum update after the parameter update
kOptimizer32bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
g
,
p
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
skip_zeros
,
n
);
kOptimizer32bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
g
,
p
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
skip_zeros
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
if
(
max_unorm
>
0.0
f
)
{
CUDA_CHECK_RETURN
(
cudaMemset
(
unorm
,
0
,
1
*
sizeof
(
float
)));
kPreconditionOptimizer32bit1State
<
T
,
OPTIMIZER
,
4096
,
8
><<<
num_blocks
,
512
>>>
(
g
,
p
,
state1
,
unorm
,
beta1
,
beta2
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
n
);
kPreconditionOptimizer32bit1State
<
T
,
OPTIMIZER
,
4096
,
8
><<<
num_blocks
,
512
>>>
(
g
,
p
,
state1
,
unorm
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
break
;
...
...
@@ -175,20 +175,20 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case
RMSPROP
:
case
ADAGRAD
:
CUDA_CHECK_RETURN
(
cudaMemset
(
new_max1
,
0
,
1
*
sizeof
(
float
)));
kPreconditionOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
256
>>>
(
p
,
g
,
state1
,
unorm
,
beta1
,
eps
,
step
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
kPreconditionOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
256
>>>
(
p
,
g
,
state1
,
unorm
,
beta1
,
beta2
,
eps
,
step
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
kOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
p
,
g
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
eps
,
step
,
lr
,
kOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
p
,
g
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
eps
,
step
,
lr
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
break
;
case
LION
:
// in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
p
,
g
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
step
,
lr
,
kOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
p
,
g
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
eps
,
step
,
lr
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
CUDA_CHECK_RETURN
(
cudaMemset
(
new_max1
,
0
,
1
*
sizeof
(
float
)));
kPreconditionOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
256
>>>
(
p
,
g
,
state1
,
unorm
,
beta1
,
beta2
,
step
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
kPreconditionOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
256
>>>
(
p
,
g
,
state1
,
unorm
,
beta1
,
beta2
,
eps
,
step
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
break
;
default:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment