Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
a6eae2e7
Commit
a6eae2e7
authored
Oct 20, 2021
by
Tim Dettmers
Browse files
Added skip_zeros; tests are passing.
parent
bb34fd50
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
102 additions
and
82 deletions
+102
-82
Makefile
Makefile
+25
-23
bitsandbytes/functional.py
bitsandbytes/functional.py
+2
-2
bitsandbytes/optim/optimizer.py
bitsandbytes/optim/optimizer.py
+4
-4
csrc/kernels.cu
csrc/kernels.cu
+70
-52
csrc/pythonInterface.c
csrc/pythonInterface.c
+1
-1
No files found.
Makefile
View file @
a6eae2e7
...
@@ -15,29 +15,31 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
...
@@ -15,29 +15,31 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
LIB
:=
-L
$(CUDA_HOME)
/lib64
-lcudart
-lcuda
-lcublas
-lcurand
-lcusparse
-L
$(CONDA_PREFIX)
/lib
LIB
:=
-L
$(CUDA_HOME)
/lib64
-lcudart
-lcuda
-lcublas
-lcurand
-lcusparse
-L
$(CONDA_PREFIX)
/lib
# NVIDIA NVCC compilation flags
# NVIDIA NVCC compilation flags
COMPUTE_CAPABILITY
:=
-gencode
arch
=
compute_35,code
=
sm_35
# Kepler
#COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_37,code
=
sm_37
# Kepler
#COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_50,code
=
sm_50
# Maxwell
#COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_52,code
=
sm_52
# Maxwell
#COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_60,code
=
sm_60
# Pascal
#COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_61,code
=
sm_61
# Pascal
#COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_70,code
=
sm_70
# Volta
#COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_72,code
=
sm_72
# Volta
#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_72,code
=
sm_72
# Volta
#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
#
# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
## CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
CC_CUDA92
:=
-gencode
arch
=
compute_30,code
=
sm_30
#CC_CUDA92 := -gencode arch=compute_30,code=sm_30
#
# Later versions of CUDA support the new architectures
## Later versions of CUDA support the new architectures
CC_CUDA10x
:=
-gencode
arch
=
compute_30,code
=
sm_30
#CC_CUDA10x := -gencode arch=compute_30,code=sm_30
CC_CUDA10x
+=
-gencode
arch
=
compute_75,code
=
sm_75
#CC_CUDA10x += -gencode arch=compute_75,code=sm_75
#
CC_CUDA110
:=
-gencode
arch
=
compute_75,code
=
sm_75
#CC_CUDA110 := -gencode arch=compute_75,code=sm_75
CC_CUDA110
+=
-gencode
arch
=
compute_80,code
=
sm_80
#CC_CUDA110 += -gencode arch=compute_80,code=sm_80
#
CC_CUDA11x
:=
-gencode
arch
=
compute_75,code
=
sm_75
#CC_CUDA11x := -gencode arch=compute_75,code=sm_75
CC_CUDA11x
+=
-gencode
arch
=
compute_80,code
=
sm_80
#CC_CUDA11x += -gencode arch=compute_80,code=sm_80
CC_CUDA11x
+=
-gencode
arch
=
compute_86,code
=
sm_86
#CC_CUDA11x += -gencode arch=compute_86,code=sm_86
COMPUTE_CAPABILITY
:=
-gencode
arch
=
compute_70,code
=
sm_70
# Volta
all
:
$(ROOT_DIR)/dependencies/cub $(BUILD_DIR)
all
:
$(ROOT_DIR)/dependencies/cub $(BUILD_DIR)
...
...
bitsandbytes/functional.py
View file @
a6eae2e7
...
@@ -486,13 +486,13 @@ def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, s
...
@@ -486,13 +486,13 @@ def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, s
str2optimizer8bit_blockwise
[
optimizer_name
][
0
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
str2optimizer8bit_blockwise
[
optimizer_name
][
0
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
absmax1
),
get_ptr
(
absmax2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
get_ptr
(
absmax1
),
get_ptr
(
absmax2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()))
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()))
elif
g
.
dtype
==
torch
.
float16
and
state1
.
dtype
==
torch
.
uint8
:
elif
g
.
dtype
==
torch
.
float16
and
state1
.
dtype
==
torch
.
uint8
:
str2optimizer8bit_blockwise
[
optimizer_name
][
1
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
str2optimizer8bit_blockwise
[
optimizer_name
][
1
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
absmax1
),
get_ptr
(
absmax2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
get_ptr
(
absmax1
),
get_ptr
(
absmax2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()))
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()))
else
:
else
:
raise
ValueError
(
f
'Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
'
)
raise
ValueError
(
f
'Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
'
)
...
...
bitsandbytes/optim/optimizer.py
View file @
a6eae2e7
...
@@ -336,7 +336,7 @@ class Optimizer2State(Optimizer8bit):
...
@@ -336,7 +336,7 @@ class Optimizer2State(Optimizer8bit):
if
state
[
'state1'
].
dtype
==
torch
.
float
:
if
state
[
'state1'
].
dtype
==
torch
.
float
:
F
.
optimizer_update_32bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
config
[
'betas'
][
0
],
config
[
'eps'
],
step
,
config
[
'lr'
],
F
.
optimizer_update_32bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
config
[
'betas'
][
0
],
config
[
'eps'
],
step
,
config
[
'lr'
],
state
[
'state2'
],
config
[
'betas'
][
1
],
config
[
'weight_decay'
],
gnorm_scale
,
state
[
'state2'
],
config
[
'betas'
][
1
],
config
[
'weight_decay'
],
gnorm_scale
,
state
[
'unorm_vec'
]
if
config
[
'max_unorm'
]
>
0.0
else
None
,
max_unorm
=
config
[
'max_unorm'
])
state
[
'unorm_vec'
]
if
config
[
'max_unorm'
]
>
0.0
else
None
,
max_unorm
=
config
[
'max_unorm'
]
,
skip_zeros
=
config
[
'skip_zeros'
]
)
elif
state
[
'state1'
].
dtype
==
torch
.
uint8
and
not
config
[
'block_wise'
]:
elif
state
[
'state1'
].
dtype
==
torch
.
uint8
and
not
config
[
'block_wise'
]:
F
.
optimizer_update_8bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
state
[
'state2'
],
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
F
.
optimizer_update_8bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
state
[
'state2'
],
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
...
@@ -352,7 +352,7 @@ class Optimizer2State(Optimizer8bit):
...
@@ -352,7 +352,7 @@ class Optimizer2State(Optimizer8bit):
F
.
optimizer_update_8bit_blockwise
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
state
[
'state2'
],
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
F
.
optimizer_update_8bit_blockwise
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
state
[
'state2'
],
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
config
[
'eps'
],
step
,
config
[
'lr'
],
config
[
'eps'
],
step
,
config
[
'lr'
],
state
[
'qmap1'
],
state
[
'qmap2'
],
state
[
'absmax1'
],
state
[
'absmax2'
],
state
[
'qmap1'
],
state
[
'qmap2'
],
state
[
'absmax1'
],
state
[
'absmax2'
],
config
[
'weight_decay'
],
gnorm_scale
=
gnorm_scale
)
config
[
'weight_decay'
],
gnorm_scale
=
gnorm_scale
,
skip_zeros
=
config
[
'skip_zeros'
]
)
class
Optimizer1State
(
Optimizer8bit
):
class
Optimizer1State
(
Optimizer8bit
):
...
@@ -450,7 +450,7 @@ class Optimizer1State(Optimizer8bit):
...
@@ -450,7 +450,7 @@ class Optimizer1State(Optimizer8bit):
F
.
optimizer_update_32bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
config
[
'betas'
][
0
],
config
[
'eps'
],
step
,
config
[
'lr'
],
F
.
optimizer_update_32bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
config
[
'betas'
][
0
],
config
[
'eps'
],
step
,
config
[
'lr'
],
None
,
0.0
,
config
[
'weight_decay'
],
gnorm_scale
,
None
,
0.0
,
config
[
'weight_decay'
],
gnorm_scale
,
state
[
'unorm_vec'
]
if
config
[
'max_unorm'
]
>
0.0
else
None
,
max_unorm
=
config
[
'max_unorm'
],
state
[
'unorm_vec'
]
if
config
[
'max_unorm'
]
>
0.0
else
None
,
max_unorm
=
config
[
'max_unorm'
],
skip_zeros
=
False
)
skip_zeros
=
config
[
'skip_zeros'
]
)
elif
state
[
'state1'
].
dtype
==
torch
.
uint8
and
not
config
[
'block_wise'
]:
elif
state
[
'state1'
].
dtype
==
torch
.
uint8
and
not
config
[
'block_wise'
]:
F
.
optimizer_update_8bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
None
,
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
F
.
optimizer_update_8bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
None
,
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
...
@@ -463,4 +463,4 @@ class Optimizer1State(Optimizer8bit):
...
@@ -463,4 +463,4 @@ class Optimizer1State(Optimizer8bit):
F
.
optimizer_update_8bit_blockwise
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
None
,
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
F
.
optimizer_update_8bit_blockwise
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
None
,
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
config
[
'eps'
],
step
,
config
[
'lr'
],
config
[
'eps'
],
step
,
config
[
'lr'
],
state
[
'qmap1'
],
None
,
state
[
'absmax1'
],
None
,
state
[
'qmap1'
],
None
,
state
[
'absmax1'
],
None
,
config
[
'weight_decay'
],
gnorm_scale
=
gnorm_scale
,
skip_zeros
=
False
)
config
[
'weight_decay'
],
gnorm_scale
=
gnorm_scale
,
skip_zeros
=
config
[
'skip_zeros'
]
)
csrc/kernels.cu
View file @
a6eae2e7
...
@@ -715,9 +715,12 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
...
@@ -715,9 +715,12 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
switch
(
OPTIMIZER
)
switch
(
OPTIMIZER
)
{
{
case
ADAM
:
case
ADAM
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
]));
if
(
!
skip_zeros
||
(
skip_zeros
&&
g_vals
[
j
]
!=
(
T
)
0.0
))
s2_vals
[
j
]
=
s2_vals
[
j
]
*
beta2
+
((
1.0
f
-
beta2
)
*
(((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
])));
{
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
+
(
update_scale
*
step_size
*
(
s1_vals
[
j
]
/
(
sqrtf
(
s2_vals
[
j
])
+
(
eps
*
correction2
))));
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
]));
s2_vals
[
j
]
=
s2_vals
[
j
]
*
beta2
+
((
1.0
f
-
beta2
)
*
(((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
])));
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
+
(
update_scale
*
step_size
*
(
s1_vals
[
j
]
/
(
sqrtf
(
s2_vals
[
j
])
+
(
eps
*
correction2
))));
}
break
;
break
;
}
}
}
}
...
@@ -865,21 +868,24 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
...
@@ -865,21 +868,24 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
# pragma unroll 4
# pragma unroll 4
for
(
unsigned
int
j
=
0
;
j
<
NUM_PER_THREAD
;
j
++
)
for
(
unsigned
int
j
=
0
;
j
<
NUM_PER_THREAD
;
j
++
)
{
{
switch
(
OPTIMIZER
)
if
(
!
skip_zeros
||
(
skip_zeros
&&
g_vals
[
j
]
!=
(
T
)
0.0
))
{
{
case
MOMENTUM
:
switch
(
OPTIMIZER
)
if
(
step
==
1
)
{
s1_vals
[
j
]
=
(
float
)
g_vals
[
j
];
case
MOMENTUM
:
else
if
(
step
==
1
)
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
float
)
g_vals
[
j
]);
s1_vals
[
j
]
=
(
float
)
g_vals
[
j
];
else
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
+
update_scale
*
(
-
lr
*
(
s1_vals
[
j
]));
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
float
)
g_vals
[
j
]);
break
;
case
RMSPROP
:
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
+
update_scale
*
(
-
lr
*
(
s1_vals
[
j
]));
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
]));
break
;
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
update_scale
*
(
lr
*
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
((
float
)
s1_vals
[
j
])
+
eps
));
case
RMSPROP
:
break
;
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
]));
}
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
update_scale
*
(
lr
*
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
((
float
)
s1_vals
[
j
])
+
eps
));
break
;
}
}
}
}
__syncthreads
();
__syncthreads
();
...
@@ -1469,11 +1475,14 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
...
@@ -1469,11 +1475,14 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
{
g_val
=
float
(
g_vals
[
j
]);
g_val
=
float
(
g_vals
[
j
]);
g_val
*=
gnorm_scale
;
g_val
*=
gnorm_scale
;
s1_vals
[
j
]
=
smem_quantiles1
[
lane_id
][
c1s
[
j
]]
*
absmax1
[
i
/
BLOCK_SIZE
];
if
(
!
skip_zeros
||
(
skip_zeros
&&
g_vals
[
j
]
!=
(
T
)
0.0
))
s1_vals
[
j
]
=
(
s1_vals
[
j
]
*
beta1
)
+
(((
1.0
f
-
beta1
)
*
g_val
));
{
s1_vals
[
j
]
=
smem_quantiles1
[
lane_id
][
c1s
[
j
]]
*
absmax1
[
i
/
BLOCK_SIZE
];
s1_vals
[
j
]
=
(
s1_vals
[
j
]
*
beta1
)
+
(((
1.0
f
-
beta1
)
*
g_val
));
s2_vals
[
j
]
=
smem_quantiles2
[
lane_id
][
c2s
[
j
]]
*
absmax2
[
i
/
BLOCK_SIZE
];
s2_vals
[
j
]
=
smem_quantiles2
[
lane_id
][
c2s
[
j
]]
*
absmax2
[
i
/
BLOCK_SIZE
];
s2_vals
[
j
]
=
(
s2_vals
[
j
]
*
beta2
)
+
(((
1.0
f
-
beta2
)
*
g_val
*
g_val
));
s2_vals
[
j
]
=
(
s2_vals
[
j
]
*
beta2
)
+
(((
1.0
f
-
beta2
)
*
g_val
*
g_val
));
}
new_local_abs_max1
=
fmaxf
(
new_local_abs_max1
,
fabsf
(
s1_vals
[
j
]));
new_local_abs_max1
=
fmaxf
(
new_local_abs_max1
,
fabsf
(
s1_vals
[
j
]));
new_local_abs_max2
=
fmaxf
(
new_local_abs_max2
,
fabsf
(
s2_vals
[
j
]));
new_local_abs_max2
=
fmaxf
(
new_local_abs_max2
,
fabsf
(
s2_vals
[
j
]));
...
@@ -1509,9 +1518,12 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
...
@@ -1509,9 +1518,12 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
# pragma unroll N_PER_TH
# pragma unroll N_PER_TH
for
(
unsigned
int
j
=
0
;
j
<
N_PER_TH
;
j
++
)
for
(
unsigned
int
j
=
0
;
j
<
N_PER_TH
;
j
++
)
{
{
g_vals
[
j
]
=
(
T
)(((
float
)
g_vals
[
j
])
+
((
step_size
*
(
__fdividef
(
s1_vals
[
j
],(
sqrtf
(
s2_vals
[
j
])
+
(
correction2
*
eps
)))))));
if
(
!
skip_zeros
||
(
skip_zeros
&&
g_vals
[
j
]
!=
(
T
)
0.0
))
if
(
weight_decay
>
0.0
f
)
{
g_vals
[
j
]
=
((
float
)
g_vals
[
j
])
*
(
1.0
f
-
(
lr
*
weight_decay
));
g_vals
[
j
]
=
(
T
)(((
float
)
g_vals
[
j
])
+
((
step_size
*
(
__fdividef
(
s1_vals
[
j
],(
sqrtf
(
s2_vals
[
j
])
+
(
correction2
*
eps
)))))));
if
(
weight_decay
>
0.0
f
)
g_vals
[
j
]
=
((
float
)
g_vals
[
j
])
*
(
1.0
f
-
(
lr
*
weight_decay
));
}
}
}
// store: 0.85/1.44 -> 2.48/1.57
// store: 0.85/1.44 -> 2.48/1.57
...
@@ -1623,23 +1635,26 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
...
@@ -1623,23 +1635,26 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
{
g_val
=
float
(
g_vals
[
j
]);
g_val
=
float
(
g_vals
[
j
]);
g_val
*=
gnorm_scale
;
g_val
*=
gnorm_scale
;
if
(
weight_decay
>
0.0
f
)
if
(
!
skip_zeros
||
(
skip_zeros
&&
g_vals
[
j
]
!=
(
T
)
0.0
))
g_val
+=
((
float
)
p_vals
[
j
])
*
weight_decay
;
{
if
(
weight_decay
>
0.0
f
)
s1_vals
[
j
]
=
smem_quantiles1
[
lane_id
][
c1s
[
j
]]
*
absmax1
[
i
/
BLOCK_SIZE
];
g_val
+=
((
float
)
p_vals
[
j
])
*
weight_decay
;
switch
(
OPTIMIZER
)
s1_vals
[
j
]
=
smem_quantiles1
[
lane_id
][
c1s
[
j
]]
*
absmax1
[
i
/
BLOCK_SIZE
];
{
case
MOMENTUM
:
switch
(
OPTIMIZER
)
if
(
step
==
1
)
{
s1_vals
[
j
]
=
g_val
;
case
MOMENTUM
:
else
if
(
step
==
1
)
s1_vals
[
j
]
=
(
s1_vals
[
j
]
*
beta1
)
+
g_val
;
s1_vals
[
j
]
=
g_val
;
break
;
else
case
RMSPROP
:
s1_vals
[
j
]
=
(
s1_vals
[
j
]
*
beta1
)
+
g_val
;
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
break
;
break
;
case
RMSPROP
:
}
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
break
;
}
}
new_local_abs_max1
=
fmaxf
(
new_local_abs_max1
,
fabsf
(
s1_vals
[
j
]));
new_local_abs_max1
=
fmaxf
(
new_local_abs_max1
,
fabsf
(
s1_vals
[
j
]));
}
}
...
@@ -1662,16 +1677,19 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
...
@@ -1662,16 +1677,19 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
# pragma unroll N_PER_TH
# pragma unroll N_PER_TH
for
(
unsigned
int
j
=
0
;
j
<
N_PER_TH
;
j
++
)
for
(
unsigned
int
j
=
0
;
j
<
N_PER_TH
;
j
++
)
{
{
switch
(
OPTIMIZER
)
if
(
!
skip_zeros
||
(
skip_zeros
&&
g_vals
[
j
]
!=
(
T
)
0.0
))
{
{
case
MOMENTUM
:
switch
(
OPTIMIZER
)
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
(
s1_vals
[
j
]);
{
break
;
case
MOMENTUM
:
case
RMSPROP
:
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
(
s1_vals
[
j
]);
g_val
=
g_vals
[
j
];
break
;
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
(
__fdividef
(
g_val
,
sqrtf
(
s1_vals
[
j
])
+
eps
));
case
RMSPROP
:
break
;
g_val
=
g_vals
[
j
];
}
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
(
__fdividef
(
g_val
,
sqrtf
(
s1_vals
[
j
])
+
eps
));
break
;
}
}
}
}
// store: 0.85/1.44 -> 2.48/1.57
// store: 0.85/1.44 -> 2.48/1.57
...
...
csrc/pythonInterface.c
View file @
a6eae2e7
...
@@ -110,7 +110,7 @@ extern "C"
...
@@ -110,7 +110,7 @@ extern "C"
float eps, int step, float lr, \
float eps, int step, float lr, \
float* quantiles1, float* quantiles2, \
float* quantiles1, float* quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
float weight_decay, float gnorm_scale,
bool skip_zeros,
int n) \
float weight_decay, float gnorm_scale, int n) \
{ \
{ \
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment