Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
c83888aa
Commit
c83888aa
authored
Mar 09, 2023
by
Phil Wang
Browse files
use epsilon as beta2 for lion, complete most of the logic in kernel.cu for all functions
parent
64bb1ae8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
18 deletions
+41
-18
bitsandbytes/optim/lion.py
bitsandbytes/optim/lion.py
+10
-7
csrc/kernels.cu
csrc/kernels.cu
+31
-11
No files found.
bitsandbytes/optim/lion.py
View file @
c83888aa
...
...
@@ -18,12 +18,13 @@ class Lion(Optimizer1State):
percentile_clipping
=
100
,
block_wise
=
True
,
):
beta1
,
beta2
=
betas
super
().
__init__
(
"lion"
,
params
,
lr
,
beta
s
,
0.
,
(
beta
1
,
0.
)
,
beta2
,
weight_decay
,
optim_bits
,
args
,
...
...
@@ -45,12 +46,13 @@ class Lion8bit(Optimizer1State):
percentile_clipping
=
100
,
block_wise
=
True
,
):
beta1
,
beta2
=
betas
super
().
__init__
(
"lion"
,
params
,
lr
,
beta
s
,
0.
,
(
beta
1
,
0.
)
,
beta2
,
weight_decay
,
8
,
args
,
...
...
@@ -72,12 +74,13 @@ class Lion32bit(Optimizer1State):
percentile_clipping
=
100
,
block_wise
=
True
,
):
beta1
,
beta2
=
betas
super
().
__init__
(
"lion"
,
params
,
lr
,
beta
s
,
0.
,
(
beta
1
,
0.
)
,
beta2
,
weight_decay
,
32
,
args
,
...
...
csrc/kernels.cu
View file @
c83888aa
...
...
@@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) {
return
__int_as_float
(
old
);
}
// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
template
<
typename
T
>
__device__
int
sgn
(
T
val
)
{
return
(
T
(
0
)
<
val
)
-
(
val
<
T
(
0
));
}
template
<
int
STOCHASTIC
>
__device__
unsigned
char
dQuantize
(
float
*
smem_code
,
const
float
rand
,
float
x
)
{
...
...
@@ -217,14 +225,6 @@ __device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *
}
}
// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
template
<
typename
T
>
__device__
int
sgn
(
T
val
)
{
return
(
T
(
0
)
<
val
)
-
(
val
<
T
(
0
));
}
__global__
void
kHistogramScatterAdd2D
(
float
*
histogram
,
int
*
index1
,
int
*
index2
,
float
*
src
,
const
int
maxidx1
,
const
int
n
)
{
const
int
tid
=
threadIdx
.
x
+
(
blockDim
.
x
*
blockIdx
.
x
);
...
...
@@ -799,6 +799,10 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
s1_vals
[
j
]
=
s1_vals
[
j
]
*
s1_vals
[
j
];
// update norm
break
;
case
LION
:
// using eps as beta2
s1_vals
[
j
]
=
s1_vals
[
j
]
*
eps
+
((
1.0
f
-
eps
)
*
(
float
)
g_vals
[
j
]);
// state update
s1_vals
[
j
]
=
s1_vals
[
j
]
*
s1_vals
[
j
];
// update norm
break
;
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
]));
// state update
s1_vals
[
j
]
=
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
(
s1_vals
[
j
])
+
eps
);
// update value
...
...
@@ -900,6 +904,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
+
update_scale
*
(
-
lr
*
(
s1_vals
[
j
]));
break
;
case
LION
:
// using eps as beta2
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
update_scale
*
(
lr
*
sgn
(((
float
)
s1_vals
[
j
])
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
]))));
s1_vals
[
j
]
=
s1_vals
[
j
]
*
eps
+
((
1.0
f
-
eps
)
*
((
float
)
g_vals
[
j
]));
break
;
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
]));
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
update_scale
*
(
lr
*
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
((
float
)
s1_vals
[
j
])
+
eps
));
...
...
@@ -1230,6 +1238,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
local_unorm
+=
s1_vals
[
j
]
*
s1_vals
[
j
];
break
;
case
LION
:
// using eps as beta2
s1_vals
[
j
]
=
s1_vals
[
j
]
*
eps
+
((
1.0
f
-
eps
)
*
g_val
);
break
;
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
break
;
...
...
@@ -1333,6 +1344,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
+
(
-
lr
*
update_scale
*
(
s1_vals
[
j
]));
break
;
case
LION
:
// using eps as beta2
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
(
lr
*
sgn
(((
float
)
s1_vals
[
j
])
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_val
))));
s1_vals
[
j
]
=
s1_vals
[
j
]
*
eps
+
((
1.0
f
-
eps
)
*
g_val
);
break
;
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
(
lr
*
__fdividef
(
g_val
,
sqrtf
(
s1_vals
[
j
])
+
eps
));
...
...
@@ -1677,6 +1692,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
s1_vals
[
j
]
=
(
s1_vals
[
j
]
*
beta1
)
+
g_val
;
break
;
case
LION
:
// using eps as beta2
s1_vals
[
j
]
=
s1_vals
[
j
]
*
eps
+
((
1.0
f
-
eps
)
*
g_val
);
break
;
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
break
;
...
...
@@ -1715,6 +1733,8 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
(
s1_vals
[
j
]);
break
;
case
LION
:
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
sgn
(((
float
)
s1_vals
[
j
])
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
])));
break
;
case
RMSPROP
:
g_val
=
g_vals
[
j
];
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
(
__fdividef
(
g_val
,
sqrtf
(
s1_vals
[
j
])
+
eps
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment