Unverified commit 6a3f9fda, authored Mar 25, 2022 by LuGY, committed by GitHub on Mar 25, 2022

[cuda] modify the fused adam, support hybrid of fp16 and fp32 (#497)

parent 920c5889
Showing 6 changed files with 258 additions and 148 deletions:
colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu   +9   -9
colossalai/kernel/cuda_native/csrc/type_shim.h             +30  -0
colossalai/nn/optimizer/fused_adam.py                      +16  -27
tests/test_optimizer/unittest_cpu_adam.py                  +44  -112
tests/test_optimizer/unittest_fused_adam.py                +61  -0
tests/test_optimizer/unittest_fused_adam_kernel.py         +98  -0
colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu

@@ -22,7 +22,7 @@ typedef enum
 using MATH_T = float;
 
-template <typename T>
+template <typename T_g, typename T_p>
 struct AdamFunctor
 {
   __device__ __forceinline__ void operator()(

@@ -50,16 +50,16 @@ struct AdamFunctor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    T *g = (T *)tl.addresses[0][tensor_loc];
+    T_g *g = (T_g *)tl.addresses[0][tensor_loc];
     g += chunk_idx * chunk_size;
 
-    T *p = (T *)tl.addresses[1][tensor_loc];
+    T_p *p = (T_p *)tl.addresses[1][tensor_loc];
     p += chunk_idx * chunk_size;
 
-    T *m = (T *)tl.addresses[2][tensor_loc];
+    T_p *m = (T_p *)tl.addresses[2][tensor_loc];
     m += chunk_idx * chunk_size;
 
-    T *v = (T *)tl.addresses[3][tensor_loc];
+    T_p *v = (T_p *)tl.addresses[3][tensor_loc];
     v += chunk_idx * chunk_size;
 
     n -= chunk_idx * chunk_size;

@@ -155,15 +155,15 @@ void multi_tensor_adam_cuda(
     bias_correction2 = 1 - std::pow(beta2, step);
   }
 
   // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
-      tensor_lists[0][0].scalar_type(), 0, "adam",
+  DISPATCH_FLOAT_AND_HALF_FOR_G_P(
+      tensor_lists[0][0].scalar_type(),
+      tensor_lists[1][0].scalar_type(), 0, "adam",
       multi_tensor_apply<4>(
           BLOCK_SIZE,
           chunk_size,
           noop_flag,
           tensor_lists,
-          AdamFunctor<scalar_t_0>(),
+          AdamFunctor<g_scalar_t_0, p_scalar_t_0>(),
          beta1,
          beta2,
          bias_correction1,
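Note on the kernel change: AdamFunctor is now templated on two element types, so the gradient list (tl.addresses[0]) may be fp16 while the parameters and moments (tl.addresses[1..3]) stay fp32, or the other way round; the arithmetic itself is still done in MATH_T (float). A rough Python-side sketch of the tensor-list layout the kernel consumes (variable names here are illustrative, not from the commit):

    # Illustrative sketch of the four tensor lists that map onto tl.addresses[0..3].
    import torch

    grads      = [torch.randn(64, device='cuda', dtype=torch.half)]   # index 0 -> T_g
    params     = [torch.rand(64, device='cuda', dtype=torch.float)]   # index 1 -> T_p
    exp_avg    = [torch.zeros(64, device='cuda')]                     # index 2 -> T_p
    exp_avg_sq = [torch.zeros(64, device='cuda')]                     # index 3 -> T_p

    tensor_lists = [grads, params, exp_avg, exp_avg_sq]
    # multi_tensor_adam_cuda dispatches on grads[0].dtype and params[0].dtype and
    # instantiates AdamFunctor<g_scalar_t_0, p_scalar_t_0> over these lists.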
colossalai/kernel/cuda_native/csrc/type_shim.h

@@ -173,6 +173,36 @@
     AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }
 
+#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
+  if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) \
+  { \
+    using g_scalar_t_##LEVEL = float; \
+    using p_scalar_t_##LEVEL = float; \
+    __VA_ARGS__; \
+  } \
+  else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half) \
+  { \
+    using g_scalar_t_##LEVEL = float; \
+    using p_scalar_t_##LEVEL = at::Half; \
+    __VA_ARGS__; \
+  } \
+  else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float) \
+  { \
+    using g_scalar_t_##LEVEL = at::Half; \
+    using p_scalar_t_##LEVEL = float; \
+    __VA_ARGS__; \
+  } \
+  else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) \
+  { \
+    using g_scalar_t_##LEVEL = at::Half; \
+    using p_scalar_t_##LEVEL = at::Half; \
+    __VA_ARGS__; \
+  } \
+  else \
+  { \
+    AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), "'"); \
+  } \
+
 template <typename T>
 __device__ __forceinline__ T reduce_block_into_lanes(T *x, T val,
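The new DISPATCH_FLOAT_AND_HALF_FOR_G_P macro enumerates the four float/half combinations of (gradient dtype, parameter dtype) and binds them to g_scalar_t_0 / p_scalar_t_0 before running the dispatched body. A loose Python analogue of that selection logic, for readability only (not part of the commit):

    # Loose Python analogue of DISPATCH_FLOAT_AND_HALF_FOR_G_P (illustrative only).
    import torch

    def dispatch_float_and_half_for_g_p(g_dtype, p_dtype, body):
        supported = (torch.float32, torch.float16)
        if g_dtype not in supported or p_dtype not in supported:
            raise TypeError(f"adam not implemented for {g_dtype}/{p_dtype}")
        # The C++ macro turns this pair into the typedefs g_scalar_t_0 / p_scalar_t_0.
        return body(g_scalar_t=g_dtype, p_scalar_t=p_dtype)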
colossalai/nn/optimizer/fused_adam.py

@@ -10,7 +10,7 @@ class FusedAdam(torch.optim.Optimizer):
     """Implements Adam algorithm.
 
     Currently GPU-only. Requires ColossalAI to be installed via
-    ``pip install -v --no-cache-dir --global-option="--cuda_ext" ./``.
+    ``pip install .``.
 
     This version of fused Adam implements 2 fusions.

@@ -18,7 +18,7 @@ class FusedAdam(torch.optim.Optimizer):
     * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
 
     :class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
-    or ``torch.optim.Adam`` with ``adam_w_mode=False``
+    or ``torch.optim.Adam`` with ``adamw_mode=False``
 
     :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.

@@ -36,7 +36,7 @@ class FusedAdam(torch.optim.Optimizer):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
             (default: False) NOT SUPPORTED in FusedAdam!
-        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
+        adamw_mode (boolean, optional): Apply L2 regularization or weight decay
             True for decoupled weight decay(also known as AdamW) (default: True)
         set_grad_none (bool, optional): whether set grad to None when zero_grad()
             method is called. (default: True)

@@ -53,7 +53,7 @@ class FusedAdam(torch.optim.Optimizer):
                  bias_correction=True,
                  betas=(0.9, 0.999),
                  eps=1e-8,
-                 adam_w_mode=True,
+                 adamw_mode=True,
                  weight_decay=0.,
                  amsgrad=False,
                  set_grad_none=True):

@@ -62,7 +62,7 @@ class FusedAdam(torch.optim.Optimizer):
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay)
         super(FusedAdam, self).__init__(params, defaults)
-        self.adam_w_mode = 1 if adam_w_mode else 0
+        self.adamw_mode = 1 if adamw_mode else 0
         self.set_grad_none = set_grad_none
 
         if multi_tensor_applier.available:
             import colossal_C

@@ -109,8 +109,7 @@ class FusedAdam(torch.optim.Optimizer):
                 group['step'] = 1
 
            # create lists for multi-tensor apply
-           g_16, p_16, m_16, v_16 = [], [], [], []
-           g_32, p_32, m_32, v_32 = [], [], [], []
+           g_l, p_l, m_l, v_l = [], [], [], []
 
            for p in group['params']:
                if p.grad is None:

@@ -127,26 +126,16 @@ class FusedAdam(torch.optim.Optimizer):
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
 
-               if p.dtype == torch.float16:
-                   g_16.append(p.grad.data)
-                   p_16.append(p.data)
-                   m_16.append(state['exp_avg'])
-                   v_16.append(state['exp_avg_sq'])
-               elif p.dtype == torch.float32:
-                   g_32.append(p.grad.data)
-                   p_32.append(p.data)
-                   m_32.append(state['exp_avg'])
-                   v_32.append(state['exp_avg_sq'])
-               else:
+               if p.dtype not in [torch.float16, torch.float32]:
                    raise RuntimeError('FusedAdam only support fp16 and fp32.')
 
-           if (len(g_16) > 0):
-               multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16],
-                                    group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
-                                    bias_correction, group['weight_decay'])
-           if (len(g_32) > 0):
-               multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32],
-                                    group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
-                                    bias_correction, group['weight_decay'])
+               g_l.append(p.grad.data)
+               p_l.append(p.data)
+               m_l.append(state['exp_avg'])
+               v_l.append(state['exp_avg_sq'])
+
+           multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
+                                beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction,
+                                group['weight_decay'])
 
        return loss
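With this change, FusedAdam.step() no longer splits parameters into separate fp16 and fp32 buckets: every parameter goes into the single g_l/p_l/m_l/v_l group and one fused launch, and the dtype combination is resolved by the CUDA dispatch above. A minimal usage sketch (assumes a CUDA device and that the colossal_C extension was built when installing ColossalAI):

    # Minimal sketch: fp16 model parameters stepped by FusedAdam (assumes colossal_C is built).
    import torch
    from colossalai.nn.optimizer.fused_adam import FusedAdam

    model = torch.nn.Linear(64, 64).cuda().half()                    # fp16 weights
    optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)

    out = model(torch.rand(8, 64, device='cuda', dtype=torch.half))
    out.float().sum().backward()
    optim.step()    # one multi_tensor_applier launch for all params, whatever their dtype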
tests/test_optimizer/unittest_cpu_adam.py

# BSD 3-Clause License
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the psutil authors nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import math

import torch

try:
    import cpu_adam
except ImportError:
    raise ImportError("import cpu_adam error")

from colossalai.testing import parameterize


def torch_adam_update(

@@ -71,45 +40,46 @@ def torch_adam_update(
     param.addcdiv_(exp_avg, denom, value=-step_size)
 
-class Test():
-
-    def __init__(self):
-        self.opt_id = 0
-
-    def assertLess(self, data_diff, threshold, msg):
-        assert data_diff < threshold, msg
-
-    def assertTrue(self, condition, msg):
-        assert condition, msg
-
-    def check_res(
-        self,
-        step,
-        lr,
-        eps,
-        beta1,
-        beta2,
-        weight_decay,
-        shape,
-        grad_dtype,
-        loss_scale,
-        use_adamw,
-        cpu_adam_op,
-    ):
-        p_data = torch.rand(shape, dtype=grad_dtype)
+def assertLess(data_diff, threshold, msg):
+    assert data_diff < threshold, msg
+
+
+def assertTrue(condition, msg):
+    assert condition, msg
+
+
+@parameterize('adamw', [True, False])
+@parameterize('step', [1, 2])
+@parameterize('loss_scale', [-1, 2**5])
+@parameterize('p_dtype', [torch.float, torch.half])
+@parameterize('g_dtype', [torch.float, torch.half])
+def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
+    lr = 1e-3
+    beta1, beta2 = 0.9, 0.999
+    eps = 1e-8
+    weight_decay = 0
+
+    for i in range(1024):
+        p_data = torch.rand(64, dtype=p_dtype)
         p_data_copy = p_data.clone().float()
-        p_grad = torch.rand(shape, dtype=grad_dtype)
+        p_grad = torch.rand(64, dtype=g_dtype)
         if loss_scale > 0:
             p_grad.mul_(loss_scale)
         p_grad_copy = p_grad.clone().float()
-        exp_avg = torch.rand(shape)
+        exp_avg = torch.rand(p_data.shape)
         exp_avg_copy = exp_avg.clone()
-        exp_avg_sq = torch.rand(shape)
+        exp_avg_sq = torch.rand(p_data.shape)
         exp_avg_sq_copy = exp_avg_sq.clone()
 
-        cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, use_adamw, True)
+        try:
+            import cpu_adam
+            cpu_adam_op = cpu_adam
+        except:
+            raise ImportError("...")
+
+        cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
         cpu_adam_op.adam_update(
-            self.opt_id,
+            0,
             step,
             lr,
             beta1,

@@ -136,62 +106,24 @@ class Test():
            exp_avg_copy,
            exp_avg_sq_copy,
            loss_scale,
-           use_adamw,
+           adamw,
        )
 
        if loss_scale > 0:
            p_grad.div_(loss_scale)
        var = p_data_copy - p_data
        data_diff = torch.max(torch.abs(var))
-       threshold = 2e-3 if grad_dtype else 1e-4
-       self.assertLess(
+       threshold = 1e-3
+       print(f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
+             f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}")
+       assertLess(
            data_diff,
            threshold,
-           f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
-           f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} loss_scale {loss_scale} grad_dtype {grad_dtype}",
+           f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, loss_scale {loss_scale}, eps "
+           f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
        )
        max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
-       self.assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
+       assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
        max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
-       self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
+       assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
        max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
-       self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
-
-    def test_cpu_adam(self):
-        lr = 0.9
-        eps = 1e-6
-        weight_decay = 0
-        for use_adamw in [False, True]:
-            for shape in [(23,), (8, 24)]:
-                for step in range(1, 2):
-                    for lr in [0.01]:
-                        for eps in [1e-8]:
-                            for beta1 in [0.9]:
-                                for beta2 in [0.999]:
-                                    for weight_decay in [0.001]:
-                                        for grad_dtype in [torch.half, torch.float]:
-                                            for loss_scale in [-1, 2**5]:
-                                                self.check_res(step, lr, eps, beta1, beta2, weight_decay, shape,
-                                                               grad_dtype, loss_scale, use_adamw, cpu_adam,)
-
-def test_cpu_adam():
-    test_case = Test()
-    test_case.test_cpu_adam()
-
-if __name__ == "__main__":
-    test = Test()
-    test.test_cpu_adam()
+       assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
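The rewritten CPU test drops the hand-rolled Test class and its nested for-loops in favour of colossalai.testing.parameterize: each stacked decorator supplies one argument, and invoking the decorated function runs every combination. A tiny illustration of the pattern (names are made up):

    # Illustrative only: stacked @parameterize decorators cover the full grid of cases.
    from colossalai.testing import parameterize

    @parameterize('p_dtype', ['fp32', 'fp16'])
    @parameterize('g_dtype', ['fp32', 'fp16'])
    def run_case(p_dtype, g_dtype):
        print(f"running with p_dtype={p_dtype}, g_dtype={g_dtype}")

    run_case()    # four combinations in total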
tests/test_optimizer/unittest_fused_adam.py (new file, 0 → 100644)

import torch
import torch.nn as nn
from torch.optim.adam import Adam
from torch.optim import AdamW

from colossalai.nn.optimizer.fused_adam import FusedAdam
from colossalai.testing import parameterize


class FC(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(64, 64))

    def forward(self, x):
        return self.fc(x)


@parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, p_dtype, g_dtype):
    model = FC().cuda().to(p_dtype)
    state = model.state_dict()
    model_copy = FC().cuda().to(p_dtype)
    model_copy.load_state_dict(state.copy())

    if adamw:
        optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
        torch_optim = AdamW(model_copy.parameters(), lr=1e-3)
    else:
        optim = FusedAdam(model.parameters(), lr=1e-3)
        torch_optim = Adam(model_copy.parameters(), lr=1e-3)

    data = torch.rand(1024, 64).cuda().to(p_dtype)
    data_copy = data.clone()
    label = torch.rand(1024, 64).cuda().to(p_dtype)

    for d, l in zip(data, label):
        y = model(d)
        loss = ((l - y)**2).sum()
        optim.zero_grad()
        loss.backward()
        if p_dtype != g_dtype:
            for i in range(len(optim.param_groups[0]['params'])):
                optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype)
        optim.step()

    for d, l in zip(data_copy, label):
        y = model_copy(d)
        loss = ((l - y)**2).sum()
        torch_optim.zero_grad()
        loss.backward()
        torch_optim.step()

    assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])

    for i in range(len(optim.param_groups[0]['params'])):
        if torch.isnan(optim.param_groups[0]['params'][i]).any() \
                or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
            continue
        assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3)
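The end-to-end test above exercises the hybrid path by recasting the gradients to g_dtype whenever it differs from the parameter dtype before calling optim.step(). The same trick in isolation (a sketch, assuming a CUDA device and a built colossal_C extension):

    # Sketch: force an fp32-parameter / fp16-gradient mismatch before stepping.
    import torch
    from colossalai.nn.optimizer.fused_adam import FusedAdam

    w = torch.nn.Parameter(torch.rand(64, device='cuda'))    # fp32 parameter
    optim = FusedAdam([w], lr=1e-3)

    (w * w).sum().backward()
    w.grad.data = w.grad.data.half()                          # fp16 gradient
    optim.step()                                              # hybrid fp32/fp16 update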
tests/test_optimizer/unittest_fused_adam_kernel.py (new file, 0 → 100644)

from numpy import dtype
import torch
import torch.nn as nn
import math

from colossalai.testing import parameterize
from colossalai.utils import multi_tensor_applier


def torch_adam_update(
    step,
    lr,
    beta1,
    beta2,
    eps,
    weight_decay,
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    loss_scale,
    use_adamw,
):
    if loss_scale > 0:
        grad.div_(loss_scale)
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    if weight_decay != 0:
        if use_adamw:
            # Perform stepweight decay
            param.mul_(1 - lr * weight_decay)
        else:
            grad = grad.add(param, alpha=weight_decay)

    # Decay the first and second moment running average coefficient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    step_size = lr / bias_correction1

    param.addcdiv_(exp_avg, denom, value=-step_size)


@parameterize('adamw', [False, True])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, step, p_dtype, g_dtype):
    try:
        import colossal_C
        fused_adam = colossal_C.multi_tensor_adam
        dummy_overflow_buf = torch.cuda.IntTensor([0])
    except:
        raise ImportError("No colossal_C kernel installed.")

    count = 0

    for i in range(1024):
        p = torch.rand(64, dtype=p_dtype).cuda()
        p_copy = p.clone().float()
        g = torch.rand(p.shape, dtype=g_dtype).cuda()
        g_copy = g.clone().float()
        m = torch.rand(p.shape).cuda()
        m_copy = m.clone()
        v = torch.rand(p.shape).cuda()
        v_copy = v.clone()

        lr = 1e-3
        beta1, beta2 = 0.9, 0.999
        eps = 1e-8
        weight_decay = 0

        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
                             True, weight_decay)

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_copy,    # fp32 data
            g_copy,    # fp32 grad
            m_copy,
            v_copy,
            -1,
            adamw,
        )

        if torch.isnan(p).any() or torch.isnan(p_copy).any():
            count += 1
            continue
        assert count < 200, "too many nans"
        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, 1e-5), \
            f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"
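The kernel-level test drives colossal_C.multi_tensor_adam directly through multi_tensor_applier and compares it against the pure-PyTorch torch_adam_update reference for every (adamw, step, p_dtype, g_dtype) combination. Both new tests need a GPU and the built extension; one possible way to run them locally (hypothetical invocation):

    # Hypothetical local runner; requires a CUDA GPU and the colossal_C extension.
    import pytest

    pytest.main([
        "-q",
        "tests/test_optimizer/unittest_fused_adam.py",
        "tests/test_optimizer/unittest_fused_adam_kernel.py",
    ])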