Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
8618bed0
Commit
8618bed0
authored
Mar 10, 2023
by
Phil Wang
Browse files
swap the order in which momentum and parameters are updated in ops.cu
parent
c5582724
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
3 deletions
+22
-3
csrc/ops.cu
csrc/ops.cu
+22
-3
No files found.
csrc/ops.cu
View file @
8618bed0
...
@@ -120,8 +120,6 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
...
@@ -120,8 +120,6 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
case
MOMENTUM
:
case
MOMENTUM
:
case
RMSPROP
:
case
RMSPROP
:
case
ADAGRAD
:
case
ADAGRAD
:
case
LION
:
if
(
max_unorm
>
0.0
f
)
if
(
max_unorm
>
0.0
f
)
{
{
CUDA_CHECK_RETURN
(
cudaMemset
(
unorm
,
0
,
1
*
sizeof
(
float
)));
CUDA_CHECK_RETURN
(
cudaMemset
(
unorm
,
0
,
1
*
sizeof
(
float
)));
...
@@ -132,6 +130,18 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
...
@@ -132,6 +130,18 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
kOptimizer32bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
g
,
p
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
skip_zeros
,
n
);
kOptimizer32bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
g
,
p
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
skip_zeros
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
break
;
break
;
case
LION
:
// in lion, the momentum update after the parameter update
kOptimizer32bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
g
,
p
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
skip_zeros
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
if
(
max_unorm
>
0.0
f
)
{
CUDA_CHECK_RETURN
(
cudaMemset
(
unorm
,
0
,
1
*
sizeof
(
float
)));
kPreconditionOptimizer32bit1State
<
T
,
OPTIMIZER
,
4096
,
8
><<<
num_blocks
,
512
>>>
(
g
,
p
,
state1
,
unorm
,
beta1
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
break
;
}
}
}
}
...
@@ -164,7 +174,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
...
@@ -164,7 +174,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case
MOMENTUM
:
case
MOMENTUM
:
case
RMSPROP
:
case
RMSPROP
:
case
ADAGRAD
:
case
ADAGRAD
:
case
LION
:
CUDA_CHECK_RETURN
(
cudaMemset
(
new_max1
,
0
,
1
*
sizeof
(
float
)));
CUDA_CHECK_RETURN
(
cudaMemset
(
new_max1
,
0
,
1
*
sizeof
(
float
)));
kPreconditionOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
256
>>>
(
p
,
g
,
state1
,
unorm
,
beta1
,
eps
,
step
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
kPreconditionOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
256
>>>
(
p
,
g
,
state1
,
unorm
,
beta1
,
eps
,
step
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
...
@@ -172,6 +181,16 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
...
@@ -172,6 +181,16 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
break
;
break
;
case
LION
:
// in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
1024
>>>
(
p
,
g
,
state1
,
unorm
,
max_unorm
,
param_norm
,
beta1
,
eps
,
step
,
lr
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
CUDA_CHECK_RETURN
(
cudaMemset
(
new_max1
,
0
,
1
*
sizeof
(
float
)));
kPreconditionOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
256
>>>
(
p
,
g
,
state1
,
unorm
,
beta1
,
eps
,
step
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
break
;
default:
default:
break
;
break
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment