OpenDAS / bitsandbytes / Commits

Commit bb34fd50, authored Oct 20, 2021 by Tim Dettmers
parent 8400b58c

    Initial plumbing for skip_zeros.

Showing 8 changed files with 86 additions and 42 deletions (+86 -42):
    bitsandbytes/functional.py       +19  -6
    bitsandbytes/optim/optimizer.py  +10  -4
    csrc/kernels.cu                   +9  -9
    csrc/kernels.cuh                  +4  -4
    csrc/ops.cu                       +8  -8
    csrc/ops.cuh                      +3  -2
    csrc/pythonInterface.c            +9  -9
    tests/test_optim.py              +24  -0
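This commit threads a new `skip_zeros` flag through every layer of the optimizer stack: the Python optimizer classes and per-parameter config in `bitsandbytes/optim/optimizer.py`, the ctypes wrappers in `bitsandbytes/functional.py`, the C entry points in `csrc/pythonInterface.c`, and the CUDA kernel signatures and launch code under `csrc/`. It is plumbing only; the docstring added below states the intent ("Whether to skip zero-valued gradients or not"), which matters for sparse updates where entries with no gradient should keep their parameters and optimizer state untouched rather than having momentum decay applied. As a rough, non-authoritative sketch of that semantics in pure PyTorch (names and structure are mine, not from this commit; the real update runs inside the CUDA kernels):

```python
# Illustrative sketch only: a plain Adam step with skip_zeros semantics.
# adam_step_skip_zeros is a made-up name; bitsandbytes does this in CUDA.
import torch

def adam_step_skip_zeros(p, g, m, v, lr=1e-3, beta1=0.9, beta2=0.999,
                         eps=1e-8, step=1, skip_zeros=False):
    m_new = beta1 * m + (1 - beta1) * g
    v_new = beta2 * v + (1 - beta2) * g * g
    m_hat = m_new / (1 - beta1 ** step)
    v_hat = v_new / (1 - beta2 ** step)
    p_new = p - lr * m_hat / (v_hat.sqrt() + eps)
    if skip_zeros:
        active = g != 0  # where grad is exactly zero, change nothing
        p_new = torch.where(active, p_new, p)
        m_new = torch.where(active, m_new, m)
        v_new = torch.where(active, v_new, v)
    return p_new, m_new, v_new
```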
bitsandbytes/functional.py
```diff
@@ -337,7 +337,7 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
                 beta1: float, eps: float, step: int, lr: float, state2: Tensor=None, beta2: float=0.0, weight_decay: float=0.0, gnorm_scale: float=1.0,
-                unorm_vec: Tensor=None, max_unorm: float=0.0) -> None:
+                unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None:
     '''
     Performs an inplace optimizer update with one or two optimizer states.
@@ -369,6 +369,12 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
         Optimizer beta2.
     gnorm_scale : float
         The factor to rescale the gradient to the max clip value.
+    unorm_vec : torch.Tensor
+        The tensor for the update norm.
+    max_unorm : float
+        The maximum update norm relative to the weight norm.
+    skip_zeros : bool
+        Whether to skip zero-valued gradients or not (default: False).
     '''

     param_norm = 0.0
@@ -381,11 +387,11 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
     if g.dtype == torch.float32 and state1.dtype == torch.float32:
         str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2),
                     get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm),
                     ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
-                    ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                    ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     elif g.dtype == torch.float16 and state1.dtype == torch.float32:
         str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2),
                     get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm),
                     ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
-                    ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                    ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     else:
         raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
@@ -439,6 +445,10 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten
         Max value for the next Adam update of the second state.
     gnorm_scale : float
         The factor to rescale the gradient to the max clip value.
+    unorm_vec : torch.Tensor
+        The tensor for the update norm.
+    max_unorm : float
+        The maximum update norm relative to the weight norm.
     '''

     param_norm = 0.0
@@ -468,19 +478,22 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten
 def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor,
                 beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, qmap2: Tensor,
-                absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0) -> None:
+                absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0,
+                skip_zeros=False) -> None:

     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
         str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
                     ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_int32(step), ct.c_float(lr),
                     get_ptr(qmap1), get_ptr(qmap2),
-                    get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                    get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+                    ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
         str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
                     ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_int32(step), ct.c_float(lr),
                     get_ptr(qmap1), get_ptr(qmap2),
-                    get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                    get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+                    ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     else:
         raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
```
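A note on the dispatch above: `str2optimizer32bit` and `str2optimizer8bit_blockwise` map an optimizer name to a pair of ctypes-bound C entry points, indexed by gradient dtype (index 0 for fp32 gradients, index 1 for fp16). A minimal mock of that pattern, with stand-in lambdas rather than the real library functions:

```python
# Mock of the name -> (fp32 entry, fp16 entry) dispatch used above.
# The lambdas are placeholders; nothing here is the real bitsandbytes API.
import torch

table = {'adam': (lambda g: 'fp32 entry point',
                  lambda g: 'fp16 entry point')}

def update(optimizer_name, g):
    if g.dtype == torch.float32:
        return table[optimizer_name][0](g)
    elif g.dtype == torch.float16:
        return table[optimizer_name][1](g)
    raise ValueError(f'unsupported gradient dtype: {g.dtype}')

print(update('adam', torch.zeros(4, dtype=torch.float16)))  # fp16 entry point
```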
bitsandbytes/optim/optimizer.py
```diff
@@ -220,6 +220,7 @@ class Optimizer8bit(torch.optim.Optimizer):
                 config['percentile_clipping'] = self.args.percentile_clipping
                 config['block_wise'] = self.args.block_wise
                 config['max_unorm'] = self.args.max_unorm
+                config['skip_zeros'] = self.args.skip_zeros

             if (gindex, pindex) in self.mng.index2config:
                 config.update(self.mng.index2config[(gindex, pindex)])
@@ -234,7 +235,8 @@ class Optimizer8bit(torch.optim.Optimizer):
 class Optimizer2State(Optimizer8bit):
     def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
             weight_decay=0.0, optim_bits=32, args=None,
-            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
+            skip_zeros=False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -259,6 +261,7 @@ class Optimizer2State(Optimizer8bit):
             args['percentile_clipping'] = percentile_clipping
             args['block_wise'] = block_wise
             args['max_unorm'] = max_unorm
+            args['skip_zeros'] = skip_zeros

             self.args = MockArgs(args)
         else:
@@ -355,7 +358,8 @@ class Optimizer2State(Optimizer8bit):
 class Optimizer1State(Optimizer8bit):
     def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
             weight_decay=0.0, optim_bits=32, args=None,
-            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
+            skip_zeros=False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -377,6 +381,7 @@ class Optimizer1State(Optimizer8bit):
             args['percentile_clipping'] = percentile_clipping
             args['block_wise'] = block_wise
             args['max_unorm'] = max_unorm
+            args['skip_zeros'] = skip_zeros

             self.args = MockArgs(args)
         else:
@@ -444,7 +449,8 @@ class Optimizer1State(Optimizer8bit):
         if state['state1'].dtype == torch.float:
             F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                     None, 0.0, config['weight_decay'], gnorm_scale,
-                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
+                    skip_zeros=False)
         elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
             F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
@@ -457,4 +463,4 @@ class Optimizer1State(Optimizer8bit):
             F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
                     config['eps'], step, config['lr'], state['qmap1'], None, state['absmax1'], None,
-                    config['weight_decay'], gnorm_scale=gnorm_scale)
+                    config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False)
```
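With the new keyword on `Optimizer2State`/`Optimizer1State` and the `skip_zeros` key wired into the `Optimizer8bit` config, the flag can be set per parameter through `GlobalOptimManager.override_config`, exactly as the new test does. Note that the two `F.optimizer_update_*` call sites changed here still pass a hard-coded `skip_zeros=False`, so the stored config value is not yet consulted at update time. A usage sketch (assumes a CUDA build of bitsandbytes; mirrors the pattern in `tests/test_optim.py`):

```python
# Usage sketch: per-parameter skip_zeros via the global override mechanism.
import torch
import bitsandbytes as bnb

p1 = torch.nn.Parameter(torch.randn(1024, 1024, device='cuda'))
p2 = torch.nn.Parameter(torch.randn(1024, 1024, device='cuda'))

mng = bnb.optim.GlobalOptimManager.get_instance()
mng.initialize()
mng.override_config(p2, 'skip_zeros', True)   # only p2 skips zero grads
mng.register_parameters([p1, p2])

adam = bnb.optim.Adam([p1, p2], lr=1e-3)      # skip_zeros defaults to False
```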
csrc/kernels.cu
```diff
@@ -654,7 +654,7 @@ __launch_bounds__(TH, 1)
 __global__ void kOptimizer32bit2State(T* g, T* p,
                 float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n)
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
 {
   const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
@@ -809,7 +809,7 @@ __launch_bounds__(TH, 1)
 __global__ void kOptimizer32bit1State(T* g, T* p,
                 float* state1, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n)
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
 {
   const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
@@ -1383,7 +1383,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
                 float* absmax1, float* absmax2,
                 float weight_decay,
-                const float gnorm_scale, const int n)
+                const float gnorm_scale, const bool skip_zeros, const int n)
 {
     //const int n_full = n + (n%BLOCK_SIZE);
@@ -1555,7 +1555,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
                 float* __restrict__ const quantiles1,
                 float* absmax1,
                 float weight_decay,
-                const float gnorm_scale, const int n)
+                const float gnorm_scale, const bool skip_zeros, const int n)
 {
     //const int n_full = n + (n%BLOCK_SIZE);
@@ -1723,7 +1723,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
 #define MAKE_Optimizer32bit1State(oname, gtype) \
 template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
-const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n); \
+const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \

 MAKE_Optimizer32bit1State(MOMENTUM, half)
 MAKE_Optimizer32bit1State(MOMENTUM, float)
@@ -1740,9 +1740,9 @@ MAKE_PreconditionOptimizer32bit2State(ADAM, half)
 MAKE_PreconditionOptimizer32bit2State(ADAM, float)
 template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
-                const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n);
+                const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
-                const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n);
+                const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
@@ -1825,7 +1825,7 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
                 float* absmax1, float* absmax2, \
                 float weight_decay, \
-                const float gnorm_scale, const int n); \
+                const float gnorm_scale, const bool skip_zeros, const int n); \

 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
@@ -1838,7 +1838,7 @@ template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block
                 float* __restrict__ const quantiles1, \
                 float* absmax1, \
                 float weight_decay, \
-                const float gnorm_scale, const int n); \
+                const float gnorm_scale, const bool skip_zeros, const int n); \

 MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
```
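An aside on the kernel bodies shown as context: the `n_full` expression at the top of both 32-bit kernels is simply `n` rounded up to a whole multiple of `TH*NUM_PER_THREAD` (the number of elements each block consumes per iteration), so the final partial chunk still gets a full pass. In Python terms (a sketch, not library code):

```python
# Round n up to a multiple of chunk, as the kernels' n_full does:
# n_full = chunk*(n // chunk) + (0 if n % chunk == 0 else chunk)
def round_up(n: int, chunk: int) -> int:
    return chunk * (n // chunk) + (0 if n % chunk == 0 else chunk)

assert round_up(5000, 4096) == 8192   # tail elements get one extra chunk
assert round_up(8192, 4096) == 8192   # exact multiples are unchanged
```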
csrc/kernels.cuh
```diff
@@ -27,7 +27,7 @@ template<typename T, int OPTIMIZER>
 __global__ void kOptimizer32bit2State(T* g, T* p,
                 float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n);
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
 __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
@@ -39,7 +39,7 @@ template<typename T, int OPTIMIZER>
 __global__ void kOptimizer32bit1State(T* g, T* p,
                 float* state1, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n);
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

 template<typename T, int OPTIMIZER> __global__ void
@@ -90,7 +90,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
                 T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
                 const float beta1, const float beta2, const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
-                float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const int n);
+                float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);

 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
@@ -99,7 +99,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
                 float* __restrict__ const quantiles1,
                 float* absmax1, float weight_decay,
-                const float gnorm_scale, const int n);
+                const float gnorm_scale, const bool skip_zeros, const int n);

 template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);
```
csrc/ops.cu
```diff
@@ -181,7 +181,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n)
+                const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
 {
   int blocks = n/4096;
   blocks = n % 4096 == 0 ? blocks : blocks + 1;
@@ -194,7 +194,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
         kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
-      kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+      kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case MOMENTUM:
@@ -206,7 +206,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
-      kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+      kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
   }
@@ -259,7 +259,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
 template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
 {
   int blocks = 0;
@@ -269,7 +269,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
       blocks = n/BLOCKSIZE_2STATE;
       blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
       kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
-                quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n);
+                quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case MOMENTUM:
@@ -277,7 +277,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
       blocks = n/BLOCKSIZE_1STATE;
       blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
       kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
-                quantiles1, absmax1, weight_decay, gnorm_scale, n);
+                quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
   }
@@ -313,7 +313,7 @@ template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *a
 template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
                 float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
                 const float beta1, const float beta2, const float eps, const float weight_decay, \
-                const int step, const float lr, const float gnorm_scale, const int n);
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

 MAKE_optimizer32bit(ADAM, half)
 MAKE_optimizer32bit(ADAM, float)
@@ -342,7 +342,7 @@ MAKE_optimizerStatic8bit(RMSPROP, float)
 #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
 template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n); \
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \

 MAKE_optimizerStatic8bitBlockwise(half, ADAM);
 MAKE_optimizerStatic8bitBlockwise(float, ADAM);
```
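The launch geometry here follows directly from the template parameters instantiated in `csrc/kernels.cu` (block size 2048, 8 values per thread): each block covers `BLOCKSIZE` elements with `BLOCKSIZE/NUM_PER_THREAD` threads, and the block count is the same round-up division seen above. A small sanity-check sketch (constants assumed from the `2048, 8` instantiations; `BLOCKSIZE_2STATE` itself is defined outside this diff):

```python
# Sketch: blockwise launch geometry, assuming BLOCKSIZE=2048, 8 per thread.
BLOCKSIZE, NUM_PER_THREAD = 2048, 8
n = 1_000_000
blocks = n // BLOCKSIZE + (0 if n % BLOCKSIZE == 0 else 1)
threads_per_block = BLOCKSIZE // NUM_PER_THREAD          # 256 threads
assert blocks * threads_per_block * NUM_PER_THREAD >= n  # covers every element
```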
csrc/ops.cuh
```diff
@@ -49,7 +49,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                 float beta1, float beta2, float eps, float weight_decay,
-                int step, float lr, const float gnorm_scale, int n);
+                int step, float lr, const float gnorm_scale, bool skip_zeros, int n);

 template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
                 float *unorm, float max_unorm, float param_norm,
@@ -62,7 +62,8 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigne
 template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n);
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay,
+                const float gnorm_scale, bool skip_zeros, int n);

 template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
```
csrc/pythonInterface.c
```diff
@@ -20,8 +20,8 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate
 void fname##32bit_g##gbits(gtype *g, gtype *p, \
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
                 const float beta1, const float beta2, const float eps, const float weight_decay, \
-                const int step, const float lr, float gnorm_scale, const int n) \
-{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
+                const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
+{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \

 MAKE_FUNC32(momentum, MOMENTUM, float, 32)
 MAKE_FUNC32(momentum, MOMENTUM, half, 16)
@@ -53,8 +53,8 @@ MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
 #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
 void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)\
-{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); }\
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
+{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\

 MAKE_BLOCKWISE8(adam, ADAM, half, 16)
 MAKE_BLOCKWISE8(adam, ADAM, float, 32)
@@ -93,8 +93,8 @@ extern "C"
 void c##name##32bit_g##gbits(gtype *g, gtype *p, \
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
                 const float beta1, const float beta2, const float eps, const float weight_decay, \
-                const int step, const float lr, const float gnorm_scale, const int n) \
-{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
+                const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
+{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \

 MAKE_CFUNC32(adam, float, 32)
 MAKE_CFUNC32(adam, half, 16)
@@ -110,7 +110,7 @@ extern "C"
                 float eps, int step, float lr, \
                 float* quantiles1, float* quantiles2, \
                 float* max1, float* max2, float* new_max1, float* new_max2, \
-                float weight_decay, float gnorm_scale, int n) \
+                float weight_decay, float gnorm_scale, bool skip_zeros, int n) \
 { \
     name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
                 quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
@@ -126,8 +126,8 @@ extern "C"
 #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
 void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n) \
-{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); } \
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
+{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \

 MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
 MAKE_CBLOCKWISE8(adam, ADAM, float, 32)
```
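Two details of the C shims are worth noting. First, ctypes passes everything positionally with no signature checking, so the `ct.c_bool(skip_zeros)` added in `functional.py` must land exactly where the wrappers declare it, immediately before `n`; a misplaced argument would be silently misread. Second, the 8-bit non-blockwise wrapper in the `@@ -110,7` hunk now accepts `bool skip_zeros` but does not forward it, which is consistent with `functional.py`, where only `optimizer_update_32bit` and `optimizer_update_8bit_blockwise` gained the flag. A small sketch of the positional convention (made-up values, no library loaded):

```python
# Sketch: the tail of the ctypes argument tuple must mirror the C order
# (..., weight_decay, gnorm_scale, skip_zeros, n). Values are hypothetical.
import ctypes as ct

tail = (
    ct.c_float(0.0),    # weight_decay
    ct.c_float(1.0),    # gnorm_scale
    ct.c_bool(False),   # skip_zeros  (the slot added by this commit)
    ct.c_int32(4096),   # n = g.numel()
)
print([type(a).__name__ for a in tail])
```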
tests/test_optim.py
```diff
@@ -141,6 +141,7 @@ def test_global_config(dim1, dim2, gtype):
     eps = 1e-8

     bnb.optim.GlobalOptimManager.get_instance().initialize()
+    bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
     bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)

     bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
@@ -155,6 +156,8 @@ def test_global_config(dim1, dim2, gtype):
     else:
         atol, rtol = 1e-4, 1e-3

+    original_p2 = p2[mask].clone()
+
     for i in range(50):
         g1 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1 + 0.001
         g2 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1 + 0.001
@@ -163,11 +166,32 @@ def test_global_config(dim1, dim2, gtype):
         p2.grad = g2
         p3.grad = g3

+        if i > 30 and i % 10 == 0:
+            g1.data[mask] = 0.0
+            g2.data[mask] = 0.0
+            p1.grad = g1
+            p2.grad = g2
+            original_p1 = p1[mask].clone()
+            original_p2 = p2[mask].clone()
+            og_s1 = adam2.state[p2]['state1'][mask].clone()
+            og_s2 = adam2.state[p2]['state2'][mask].clone()
+            og_s11 = adam2.state[p1]['state1'][mask].clone()
+            og_s21 = adam2.state[p1]['state2'][mask].clone()
+
         adam2.step()

         assert adam2.state[p3]['state1'].dtype == torch.uint8
         assert adam2.state[p3]['state2'].dtype == torch.uint8

+        if i > 30 and i % 10 == 0:
+            torch.testing.assert_allclose(original_p2, p2[mask])
+            torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
+            torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
+            assert ((p1[mask] - original_p1) == 0.0).sum() < p1.numel()
+            assert ((adam2.state[p1]['state1'][mask] - og_s11) == 0.0).sum() == 0.0
+            assert ((adam2.state[p1]['state2'][mask] - og_s21) == 0.0).sum() == 0.0
+
 dim1 = [1024]
```
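The new test logic zeroes the masked gradient entries every tenth step after step 30 and asserts two complementary things: `p2`, configured with `skip_zeros=True` via `override_config`, must keep its masked parameters and both state tensors exactly unchanged across the step, while `p1`, with the default `skip_zeros=False`, must have updated state everywhere, because a standard momentum/Adam update decays state even when the gradient is zero. A self-contained sketch of that second point:

```python
# Why p1's state changes even where its gradient is zero (sketch):
# a plain exponential-moving-average update decays state on zero grads.
import torch

m = torch.ones(4)     # state1 (momentum) before the step
g = torch.zeros(4)    # zero gradient
beta1 = 0.9
m_after = beta1 * m + (1 - beta1) * g
assert not torch.equal(m_after, m)   # state still moved: 1.0 -> 0.9
```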