OpenDAS / ColossalAI / Commits

Unverified commit e79ea442, authored Mar 15, 2022 by Frank Lee, committed by GitHub on Mar 15, 2022

[fp16] refactored fp16 optimizer (#392)

parent f8a0e7fb

Showing 10 changed files with 380 additions and 354 deletions (+380 −354)
colossalai/amp/naive_amp/__init__.py                          +16  −5
colossalai/amp/naive_amp/_fp16_optimizer.py                   +198 −293
colossalai/amp/naive_amp/_utils.py                            +40  −0
colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py      +0   −2
colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py   +11  −3
colossalai/amp/naive_amp/naive_amp.py                         +2   −8
colossalai/initialize.py                                      +1   −1
colossalai/zero/sharded_optim/sharded_optim.py                +28  −41
colossalai/zero/sharded_optim/sharded_optim_v2.py             +1   −1
tests/test_amp/test_naive_fp16.py                             +83  −0
colossalai/amp/naive_amp/__init__.py  (+16 −5)

+import inspect
 import torch.nn as nn
 from torch.optim import Optimizer
 from colossalai.utils import is_no_pp_or_last_stage
 from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
+from .grad_scaler import DynamicGradScaler, ConstantGradScaler


 def convert_to_naive_amp(model: nn.Module,
                          optimizer: Optimizer,
                          amp_config):
     """A helper function to wrap training components with naive AMP modules

     :param model: your model object
     ...
...
@@ -31,7 +30,19 @@ def convert_to_naive_amp(model: nn.Module,
     output_to_fp32 = is_no_pp_or_last_stage()
     model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

-    optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
+    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
+    if use_dynamic_grad_scaler:
+        scaler_class = DynamicGradScaler
+    else:
+        scaler_class = ConstantGradScaler
+
+    sig = inspect.signature(scaler_class.__init__)
+    kwargs = dict()
+    for param in sig.parameters.values():
+        if param.name in amp_config:
+            kwargs[param.name] = amp_config.pop(param.name)
+    grad_scaler = scaler_class(**kwargs)
+
+    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
     return model, optimizer
...
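Note (not part of the commit): a minimal usage sketch of the new config handling. Keys in amp_config that match the chosen grad scaler's __init__ signature are popped and forwarded to the scaler; whatever remains (for example clip_grad_norm) is passed on to NaiveAMPOptimizer. The model, optimizer and values below are illustrative, and in practice this runs inside an initialized ColossalAI context (see the new unit test at the bottom of this commit).

    # Hypothetical example, assuming a CUDA device and a prior colossalai.launch.
    import torch
    import torch.nn as nn
    from colossalai.amp.naive_amp import convert_to_naive_amp

    model = nn.Linear(16, 16).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # 'dynamic_grad_scale' selects DynamicGradScaler (the default) vs ConstantGradScaler;
    # 'initial_scale' matches DynamicGradScaler.__init__ and is routed to the scaler;
    # 'clip_grad_norm' stays in amp_config and is forwarded to NaiveAMPOptimizer.
    amp_config = dict(dynamic_grad_scale=True, initial_scale=2 ** 16, clip_grad_norm=1.0)
    model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)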
colossalai/amp/naive_amp/_fp16_optimizer.py  (+198 −293)

...
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 import torch
+import torch.distributed as dist
 try:
     import colossal_C
...
@@ -9,41 +10,30 @@ except:
     print('Colossalai should be built with cuda extension to use the FP16 optimizer')

 from torch.optim import Optimizer
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
 from colossalai.logging import get_dist_logger
-from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
-                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
+from colossalai.utils import (copy_tensor_parallel_attributes,
+                              clip_grad_norm_fp32, multi_tensor_applier)
+from torch.distributed import ProcessGroup
+from .grad_scaler import BaseGradScaler
+from ._utils import has_inf_or_nan, zero_gard_by_list

-
-def _zero_grad_group_helper(group, set_to_none):
-    """Zero out the gradient for a group of parameters.
-    Note: copied from torch.optim.optimizer."""
-    for param in group:
-        if param.grad is not None:
-            if set_to_none:
-                param.grad = None
-            else:
-                if param.grad.grad_fn is not None:
-                    param.grad.detach_()
-                else:
-                    param.grad.requires_grad_(False)
-                param.grad.zero_()
+__all__ = ['FP16Optimizer']


 def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
-    """Use multi-tensor-applier to copy values from one list to another.
+    """
+    adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)
+
+    Use multi-tensor-applier to copy values from one list to another.
     We don't have a blfoat16 implementation so for now if the overflow_buf
     is not provided, we default back to simple loop copy to be compatible
-    with bfloat16."""
+    with bfloat16.
+    """
     if overflow_buf:
         overflow_buf.fill_(0)
         # Scaling with factor `1.0` is equivalent to copy.
         multi_tensor_applier(colossal_C.multi_tensor_scale,
                              overflow_buf,
                              [this, that],
                              1.0)
     else:
         for this_, that_ in zip(this, that):
             that_.copy_(this_)
...
@@ -111,8 +101,7 @@ class DynamicGradScaler:
             self._hysteresis_tracker -= 1
             # Now if we are out of hysteresis count, scale down the loss.
             if self._hysteresis_tracker <= 0:
                 self._scale = torch.max(self._scale * self.backoff_factor,
                                         self.min_scale)
                 if self.verbose:
                     self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
         else:
...
@@ -127,12 +116,13 @@ class DynamicGradScaler:
                 if self._max_scale is not None and self._scale >= self._max_scale:
                     if self.verbose:
                         self._logger.info(
                             f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed',
                             ranks=[0])
                 else:
                     self._scale = self._scale * self.growth_factor
                     if self.verbose:
                         self._logger.info(
                             f'no consecutive overflow, loss scale is adjusted to {self._scale}',
                             ranks=[0])

     def state_dict(self):
         state_dict = {}
...
@@ -173,326 +163,241 @@ class FP16Optimizer(Optimizer):
     """

-    def __init__(self,
-                 optimizer,
-                 clip_grad=0,
-                 log_num_zeros_in_grad=False,
-                 initial_scale=2 ** 32,
-                 min_scale=1,
-                 growth_factor=2,
-                 backoff_factor=0.5,
-                 growth_interval=1000,
-                 hysteresis=2,
-                 max_scale: int = 2 ** 32,
-                 verbose: bool = False):
-        # default args for compatibility
-        bf16 = False
-        params_have_main_grad = False
-
-        # have a defaults for compatibility with pytorch optim
-        self.defaults = optimizer.defaults
-
-        # log config
-        self._logger = get_dist_logger()
-        if verbose:
-            self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
-                              f"Optimizer: {optimizer.__class__.__name__}\n"
-                              f"clip_grad = {clip_grad}\n"
-                              f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n"
-                              f"initial_scale = {initial_scale}\n"
-                              f"min_scale = {min_scale}\n"
-                              f"growth_factor = {growth_factor}\n"
-                              f"backoff_factor = {backoff_factor}\n"
-                              f"growth_interval = {growth_interval}\n"
-                              f"hysteresis = {hysteresis}\n"
-                              f"==========================================", ranks=[0])
-
-        """Input optimizer is the base optimizer for example Adam."""
-        self.optimizer = optimizer
-        assert self.optimizer, 'no optimizer is provided.'
-
-        # Set gradient clipping and logging params.
-        self.clip_grad = clip_grad
-        self.log_num_zeros_in_grad = log_num_zeros_in_grad
-        self.params_have_main_grad = params_have_main_grad
-        self.bf16 = bf16
-        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
-                                             min_scale=min_scale,
-                                             growth_factor=growth_factor,
-                                             backoff_factor=backoff_factor,
-                                             growth_interval=growth_interval,
-                                             hysteresis=hysteresis,
-                                             max_scale=max_scale,
-                                             verbose=verbose)
-
-        # None grad scaler is only supported for bf16.
-        if self.grad_scaler is None:
-            assert self.bf16, 'fp16 expects a grad scaler.'
-
-        # Tensor used to determine if a nan/if has happend.
-        # Any non-zero value indicates inf/nan.
-        # Note that we keep this for the cases that grad scaler is none.
-        # We still record nan/inf if we have a bfloat16 with a grad scaler.
-        if self.grad_scaler:
-            self.found_inf = torch.cuda.FloatTensor([0.0])
-
-        # Dummy tensor needed for apex multi-apply tensor.
-        # For bfloat, we don't have multi-tensor apply and for now
-        # we set it to none so the multi-tensor apply gets ignored.
-        if bf16:
-            self._dummy_overflow_buf = None
-        else:
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
-
-        # In case grad scaler is not passed, define the unity scale.
-        if self.grad_scaler is None:
-            self._scale_one = torch.cuda.FloatTensor([1.0])
-
-        # ======================
-        # main parameter stuff
-        # ======================
-
-        # Three groups of parameters:
-        # float16_groups: original float16 parameters
-        # fp32_from_float16_groups: fp32 copy of float16 parameters
-        # fp32_from_fp32_groups: original fp32 parameters
-        self.float16_groups = []
-        self.fp32_from_float16_groups = []
-        self.fp32_from_fp32_groups = []
-
-        # For all the groups in the original optimizer:
-        for param_group in self.optimizer.param_groups:
-            float16_params_this_group = []
-            fp32_params_this_group = []
-            fp32_from_float16_params_this_group = []
-            # For all the parameters in this group:
-            for i, param in enumerate(param_group['params']):
-                if param.requires_grad:
-                    # float16 params:
-                    if param.type() in ['torch.cuda.HalfTensor',
-                                        'torch.cuda.BFloat16Tensor']:
-                        float16_params_this_group.append(param)
-                        # Create a copy
-                        main_param = param.detach().clone().float()
-                        # Copy tensor model parallel attributes.
-                        copy_tensor_parallel_attributes(param, main_param)
-                        # if hasattr(param, 'shared'):
-                        #     main_param.shared = param.shared
-
-                        # Replace the optimizer params with the new fp32 copy.
-                        param_group['params'][i] = main_param
-                        fp32_from_float16_params_this_group.append(main_param)
-                        # Reset existing state dict key to the new main param.
-                        if param in self.optimizer.state:
-                            self.optimizer.state[main_param] \
-                                = self.optimizer.state.pop(param)
-
-                    # fp32 params.
-                    elif param.type() == 'torch.cuda.FloatTensor':
-                        fp32_params_this_group.append(param)
-                        param_group['params'][i] = param
-                    else:
-                        raise TypeError('Wrapped parameters must be one of '
-                                        'torch.cuda.FloatTensor, '
-                                        'torch.cuda.HalfTensor, or '
-                                        'torch.cuda.BFloat16Tensor. '
-                                        'Received {}'.format(param.type()))
-
-            self.float16_groups.append(float16_params_this_group)
-            self.fp32_from_float16_groups.append(
-                fp32_from_float16_params_this_group)
-            self.fp32_from_fp32_groups.append(fp32_params_this_group)
-
-        # Leverage state_dict() and load_state_dict() to
-        # recast preexisting per-param state tensors
-        self.optimizer.load_state_dict(self.optimizer.state_dict())
-
-    def zero_grad(self, set_to_none=False):
-        """We only need to zero the model related parameters, i.e.,
-        float16_groups & fp32_from_fp32_groups."""
-        for group in self.float16_groups:
-            _zero_grad_group_helper(group, set_to_none)
-        for group in self.fp32_from_fp32_groups:
-            _zero_grad_group_helper(group, set_to_none)
-
-    def get_loss_scale(self):
-        if self.grad_scaler is None:
-            return self._scale_one
-        return self.grad_scaler.scale
-
-    def _copy_model_grads_to_main_grads(self):
-        # This only needs to be done for the float16 group.
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                if self.params_have_main_grad:
-                    main_param.grad = model_param.main_grad.float()
-                else:
-                    if model_param.grad is not None:
-                        main_param.grad = model_param.grad.float()
-        # For fp32 grads, we need to reset the grads to main grad.
-        if self.params_have_main_grad:
-            for model_group in self.fp32_from_fp32_groups:
-                for model_param in model_group:
-                    model_param.grad = model_param.main_grad
-
-    def _unscale_main_grads_and_check_for_nan(self):
-        main_grads = []
-        # fp32 params fromm float16 ones.
-        for main_group in self.fp32_from_float16_groups:
-            for main_param in main_group:
-                if main_param.grad is not None:
-                    main_grads.append(main_param.grad.data)
-        # Append fp32 parameters.
-        for main_group in self.fp32_from_fp32_groups:
-            for main_param in main_group:
-                if main_param.grad is not None:
-                    main_grads.append(main_param.grad.data)
-        # Reset found inf.
-        self.found_inf.fill_(0.0)
-        # Unscale and set found inf/nan
-        torch._amp_foreach_non_finite_check_and_unscale_(
-            main_grads, self.found_inf, self.grad_scaler.inv_scale)
-        # Update across all model parallel instances.
-        torch.distributed.all_reduce(self.found_inf,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.MODEL))
-        # Check for nan.
-        found_inf_flag = (self.found_inf.item() > 0)
-        return found_inf_flag
-
-    def _get_model_and_main_params_data_float16(self):
-        model_data = []
-        main_data = []
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                model_data.append(model_param.data)
-                main_data.append(main_param.data)
-        return model_data, main_data
-
-    def _copy_main_params_to_model_params(self):
-        # Only needed for the float16 params.
-        model_data, main_data = self._get_model_and_main_params_data_float16()
-        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
-                                        overflow_buf=self._dummy_overflow_buf)
-
-    def _copy_model_params_to_main_params(self):
-        # Only needed for the float16 params.
-        model_data, main_data = self._get_model_and_main_params_data_float16()
-        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
-                                        overflow_buf=self._dummy_overflow_buf)
-
-    def reload_model_params(self):
-        self._copy_model_params_to_main_params()
-
-    @torch.no_grad()
-    def step(self):
-        # Copy gradients from model params to main params.
-        self._copy_model_grads_to_main_grads()
-
-        # Do unscale, check for inf, and update grad scaler only for
-        # the case that grad scaler is provided.
-        if self.grad_scaler:
-            # Unscale and check for inf/nan.
-            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
-
-            # We are done with scaling gradients
-            # so we can update the loss scale.
-            self.grad_scaler.update(found_inf_flag)
-
-            # If we found inf/nan, skip the update.
-            if found_inf_flag:
-                return False, None, None
-
-        # Clip the main gradients.
-        grad_norm = None
-        if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
-
-        # count the zeros in the grads
-        num_zeros_in_grad = self.count_zeros() if \
-            self.log_num_zeros_in_grad else None
-
-        # Step the optimizer.
-        self.optimizer.step()
-
-        # Update params from main params.
-        self._copy_main_params_to_model_params()
-
-        # Successful update.
-        return True, grad_norm, num_zeros_in_grad
-
-    def state_dict(self):
-        state_dict = {}
-        state_dict['optimizer'] = self.optimizer.state_dict()
-        if self.grad_scaler:
-            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
-        return state_dict
-
-    def load_state_dict(self, state_dict):
-        # Optimizer.
-        optimizer_key = 'optimizer'
-        if optimizer_key not in state_dict:
-            optimizer_key = 'optimizer_state_dict'
-            print_rank_0('***WARNING*** loading optimizer from '
-                         'an old checkpoint ...')
-        self.optimizer.load_state_dict(state_dict[optimizer_key])
-
-        # Grad scaler.
-        if 'grad_scaler' not in state_dict:
-            print_rank_0('***WARNING*** found an old checkpoint, will not '
-                         'load grad scaler ...')
-        else:
-            if self.grad_scaler:
-                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
-            else:
-                print_rank_0('***WARNING*** fould the grad scaler in the '
-                             'checkpoint but it is None in the class. '
-                             'Skipping loading grad scaler ...')
-
-        # Copy data for the main params.
-        fp32_from_float16_params_key = 'fp32_from_fp16_params'
-        if fp32_from_float16_params_key not in state_dict:
-            fp32_from_float16_params_key = 'fp32_from_fp16'
-        for current_group, saved_group in zip(
-                self.fp32_from_float16_groups,
-                state_dict[fp32_from_float16_params_key]):
-            for current_param, saved_param in zip(current_group, saved_group):
-                current_param.data.copy_(saved_param.data)
-
-    def clip_grad_norm(self, clip_grad):
-        params = []
-        for param_group in self.optimizer.param_groups:
-            for param in param_group['params']:
-                params.append(param)
-        return clip_grad_norm_fp32(params, clip_grad)
-
-    def count_zeros(self):
-        params = self.get_parameters()
-        return count_zeros_fp32(params)
-
-    def scale_loss(self, loss):
-        """Simple scaling."""
-        return self.get_loss_scale() * loss
+    def __init__(self,
+                 optimizer: Optimizer,
+                 grad_scaler: BaseGradScaler,
+                 verbose: bool = False,
+                 clip_grad_norm=0,
+                 dp_process_group: ProcessGroup = None,
+                 mp_process_group: ProcessGroup = None):
+        # have a defaults for compatibility with pytorch optim
+        self._optimizer = optimizer
+        self._defaults = optimizer.defaults
+
+        # fp16-related params
+        assert isinstance(grad_scaler, BaseGradScaler)
+        self._grad_scaler = grad_scaler
+        self._found_overflow = torch.cuda.FloatTensor([0.0])
+        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+
+        # misc params
+        self._clip_grad_max_norm = clip_grad_norm
+
+        # get process group
+        def _get_process_group(parallel_mode):
+            if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA):
+                return gpc.get_group(ParallelMode.DATA)
+            else:
+                return None
+
+        if dp_process_group is None:
+            dp_process_group = _get_process_group(ParallelMode.DATA)
+        if mp_process_group is None:
+            mp_process_group = _get_process_group(ParallelMode.MODEL)
+
+        self._dp_process_group = dp_process_group
+        self._mp_process_group = mp_process_group
+
+        # we maintain three groups of parameters
+        # so that the model can have a mixture
+        # of fp16 and fp32 params
+        # fp16_param_groups: the fp16 params of the model
+        # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model
+        # fp32_param_groups: the fp32 params of the model
+        # NOTE:
+        # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence
+        # 2. fp32_param_groups and fp16_param_groups are exclusive of each other
+        self._fp16_param_groups = []
+        self._fp32_master_param_groups = []
+        self._fp32_param_groups = []
+
+        # For all the groups in the original optimizer:
+        for param_group in self._optimizer.param_groups:
+            fp16_params = []
+            fp32_master_params = []
+            fp32_params = []
+
+            # For all the parameters in this group:
+            for i, param in enumerate(param_group['params']):
+                if param.requires_grad:
+                    # float16 params:
+                    if param.type() in ['torch.cuda.HalfTensor']:
+                        fp16_params.append(param)
+
+                        # Create a fp32 copy
+                        fp32_param = param.detach().clone().float()
+                        # Copy tensor model parallel attributes.
+                        copy_tensor_parallel_attributes(param, fp32_param)
+
+                        # Replace the optimizer params with the new fp32 copy.
+                        param_group['params'][i] = fp32_param
+                        fp32_master_params.append(fp32_param)
+
+                        # Reset existing state dict key to the new main param.
+                        if param in self._optimizer.state:
+                            self._optimizer.state[fp32_param] = self._optimizer.state.pop(param)
+
+                    # fp32 params.
+                    elif param.type() == 'torch.cuda.FloatTensor':
+                        fp32_params.append(param)
+                    else:
+                        raise TypeError('Expected parameter of type torch.cuda.FloatTensor '
+                                        f'or torch.cuda.HalfTensor, but got {param.type()}')
+
+            self._fp16_param_groups.append(fp16_params)
+            self._fp32_master_param_groups.append(fp32_master_params)
+            self._fp32_param_groups.append(fp32_params)
+
+        # Leverage state_dict() and load_state_dict() to
+        # recast preexisting per-param state tensors
+        self._optimizer.load_state_dict(self._optimizer.state_dict())
+
+        # log config
+        self._logger = get_dist_logger()
+        if verbose:
+            self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
+                              f"Optimizer: {optimizer.__class__.__name__}\n"
+                              f"clip_grad_norm = {clip_grad_norm}\n"
+                              f"grad_scaler = {self._grad_scaler.__class__.__name__}"
+                              f"==========================================",
+                              ranks=[0])
+
+    @property
+    def grad_scaler(self):
+        return self._grad_scaler
+
+    @property
+    def loss_scale(self):
+        return self._grad_scaler.scale
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
+    @property
+    def defaults(self):
+        return self._defaults
+
+    def _check_overflow(self):
+        # clear previous overflow record
+        self._found_overflow.fill_(0.0)
+
+        # check for overflow
+        for group in self._optimizer.param_groups:
+            for p in group['params']:
+                if has_inf_or_nan(p.grad):
+                    self._found_overflow.fill_(1.0)
+                    break
+
+        # all-reduce across dp group
+        if self._dp_process_group:
+            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group)
+
+        # all-reduce over model parallel group
+        if self._mp_process_group:
+            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group)
+
+        return self._found_overflow.item() > 0
+
+    def zero_grad(self, set_to_none=True):
+        # set_to_none = True can save some memory space
+        for param_group in self._optimizer.param_groups:
+            zero_gard_by_list(param_group['params'], set_to_none=set_to_none)
+
+    def _get_fp32_param_groups_to_update(self):
+        return self._fp32_master_param_groups + self._fp32_param_groups
+
+    def _unscale_grads(self):
+        for group in self._get_fp32_param_groups_to_update():
+            for p in group:
+                if p.grad is not None:
+                    p.grad.data.div_(self.loss_scale)
+
+    def _assign_grad_to_fp32_master_param(self):
+        # This only needs to be done for the float16 group.
+        for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups,
+                                                             self._fp32_master_param_groups):
+            for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
+                fp32_param.grad = fp16_param.grad.float()
+                # clear unneeded grad on fp16 param
+                fp16_param.grad = None
+
+    def _update_fp16_param_from_fp32_param(self):
+        fp16_param_data = []
+        fp32_master_param_data = []
+        for fp16_group, fp32_group in zip(self._fp16_param_groups,
+                                          self._fp32_master_param_groups):
+            for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+                fp16_param_data.append(fp16_param.data)
+                fp32_master_param_data.append(fp32_param.data)
+        _multi_tensor_copy_this_to_that(this=fp32_master_param_data,
+                                        that=fp16_param_data,
+                                        overflow_buf=self._dummy_overflow_buf)
+
+    def step(self):
+        # Copy gradients from model params to main params.
+        self._assign_grad_to_fp32_master_param()
+        self._unscale_grads()
+
+        overflow = self._check_overflow()
+        self._grad_scaler.update(overflow)
+
+        if overflow:
+            self.zero_grad()
+            return False, None
+
+        # Clip the main gradients.
+        grad_norm = None
+        if self._clip_grad_max_norm > 0.0:
+            grad_norm = self.clip_grad_norm(self._clip_grad_max_norm)
+
+        # Step the optimizer.
+        self._optimizer.step()
+
+        # Update params from main params.
+        self._update_fp16_param_from_fp32_param()
+
+        # Successful update.
+        return True, grad_norm
+
+    def backward(self, loss):
+        scaled_loss = loss * self.grad_scaler.scale
+        scaled_loss.backward()
+
+    def state_dict(self):
+        state_dict = {}
+        state_dict['optimizer'] = self._optimizer.state_dict()
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        # Optimizer.
+        self._optimizer.load_state_dict(state_dict['optimizer'])
+
+        # Grad scaler.
+        if 'grad_scaler' in state_dict:
+            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+
+        # Copy data for the main params.
+        if 'fp32_master_param_groups' in state_dict:
+            for current_group, ckpt_group in zip(self._fp32_master_param_groups,
+                                                 state_dict['fp32_master_param_groups']):
+                for current_param, ckpt_param in zip(current_group, ckpt_group):
+                    current_param.data.copy_(ckpt_param.data)
+
+    def get_parameters(self):
+        params = []
+        for param_group in self._optimizer.param_groups:
+            for param in param_group['params']:
+                params.append(param)
+        return params
+
+    def clip_grad_norm(self, clip_grad):
+        params = self.get_parameters()
+        return clip_grad_norm_fp32(params, clip_grad)

     # Promote state so it can be retrieved or set via
     # "optimizer_instance.state"
     def _get_state(self):
-        return self.optimizer.state
+        return self._optimizer.state

     def _set_state(self, value):
-        self.optimizer.state = value
+        self._optimizer.state = value

     state = property(_get_state, _set_state)
...
@@ -500,9 +405,9 @@ class FP16Optimizer(Optimizer):
     # "optimizer_instance.param_groups"
     # (for example, to adjust the learning rate)
     def _get_param_groups(self):
-        return self.optimizer.param_groups
+        return self._optimizer.param_groups

     def _set_param_groups(self, value):
-        self.optimizer.param_groups = value
+        self._optimizer.param_groups = value

     param_groups = property(_get_param_groups, _set_param_groups)
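Note (not part of the commit): in the refactored class, backward() multiplies the loss by the scaler's current scale before calling backward(), and step() copies fp16 gradients onto the fp32 master params, unscales them, checks for overflow across the DP/MP groups, updates the scaler, and either skips the update (returning False, None) or clips, steps the wrapped optimizer, copies the fp32 masters back into the fp16 params, and returns (True, grad_norm). A hedged sketch of driving it directly, assuming a CUDA build of Colossal-AI (for colossal_C) and, in a real run, a prior colossalai.launch:

    # Hypothetical sketch; values and the tiny model are illustrative.
    import torch
    import torch.nn as nn
    from colossalai.amp.naive_amp._fp16_optimizer import FP16Optimizer
    from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

    model = nn.Linear(8, 8).cuda().half()
    base_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    optim = FP16Optimizer(base_optim,
                          grad_scaler=DynamicGradScaler(initial_scale=2 ** 16),
                          clip_grad_norm=1.0)

    data = torch.randn(4, 8, device='cuda', dtype=torch.half)
    optim.zero_grad()
    loss = model(data).mean()
    optim.backward(loss)               # scales the loss before calling backward()
    success, grad_norm = optim.step()  # success is False when an overflow skips the update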
colossalai/amp/naive_amp/_utils.py  (new file, +40 −0)

from typing import List
from torch import Tensor


def has_inf_or_nan(tensor):
    try:
        # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
        # Pytorch's .sum() creates a one-element tensor of the same type as tensor
        # (which is true for some recent version of pytorch).
        tensor_sum = float(tensor.float().sum())
        # More efficient version that can be used if .sum() returns a Python scalar
        # tensor_sum = float(tensor.sum())
    except RuntimeError as instance:
        # We want to check if inst is actually an overflow exception.
        # RuntimeError could come from a different error.
        # If so, we still want the exception to propagate.
        if "value cannot be converted" not in instance.args[0]:
            raise
        return True
    else:
        if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
            return True
        return False


def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
    """
    Clear the gradient of a list of tensors,
    Note: copied from torch.optim.optimizer.
    """
    for param in tensor_list:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()
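Note (not part of the commit): has_inf_or_nan is the per-gradient check used by the new _check_overflow. A quick illustration; CPU tensors work here too since the helper only relies on .float().sum():

    import torch
    from colossalai.amp.naive_amp._utils import has_inf_or_nan

    clean = torch.randn(4)
    overflowed = torch.tensor([1.0, float('inf'), 0.0])
    nan_grad = torch.tensor([float('nan')])

    print(has_inf_or_nan(clean))       # False
    print(has_inf_or_nan(overflowed))  # True
    print(has_inf_or_nan(nan_grad))    # True (NaN fails the x == x comparison)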
colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py  (+0 −2)

...
@@ -28,12 +28,10 @@ class BaseGradScaler(ABC):
     def inv_scale(self) -> Tensor:
         return self._scale.double().reciprocal().float()

-    @abstractmethod
     def state_dict(self) -> Dict:
         state_dict = dict()
         state_dict['scale'] = self.scale

-    @abstractmethod
     def load_state_dict(self, state_dict: Dict) -> None:
         self._scale = state_dict['scale']
...
colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py  (+11 −3)

...
@@ -16,11 +16,19 @@ class DynamicGradScaler(BaseGradScaler):
                  growth_interval: int = 1000,
                  min_scale: int = None,
                  max_scale: int = None,
-                 hysteresis: int = None,
+                 hysteresis: int = 2,
                  verbose: bool = False):
         super().__init__(initial_scale, verbose)
-        self._min_scale = min_scale
-        self._max_scale = max_scale
+        if min_scale:
+            self._min_scale = torch.cuda.FloatTensor([min_scale])
+        else:
+            self._min_scale = None
+
+        if max_scale:
+            self._max_scale = torch.cuda.FloatTensor([max_scale])
+        else:
+            self._max_scale = None
+
         self._growth_factor = growth_factor
         self._backoff_factor = backoff_factor
         self._growth_interval = growth_interval
...
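Note (not part of the commit): a hedged sketch of how the dynamic scaler is exercised by the optimizer. The parameter names come straight from the diff; exact step-by-step values depend on the implementation, and the scale tensors require a CUDA device:

    from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

    scaler = DynamicGradScaler(initial_scale=2 ** 16,
                               growth_factor=2,
                               backoff_factor=0.5,
                               growth_interval=1000,
                               hysteresis=2,
                               min_scale=1,
                               max_scale=2 ** 32)

    scaler.update(True)    # repeated overflows back the scale off by backoff_factor (bounded by min_scale)
    scaler.update(False)   # growth_interval consecutive clean steps grow it again (bounded by max_scale)
    print(scaler.scale)    # current loss scale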
colossalai/amp/naive_amp/naive_amp.py  (+2 −8)

...
@@ -26,17 +26,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
     """

     def __init__(self, optim: Optimizer, *args, **kwargs):
-        optim = FP16Optimizer(optimizer=optim, *args, **kwargs)
+        optim = FP16Optimizer(optim, *args, **kwargs)
         super().__init__(optim)

     def backward(self, loss: Tensor):
-        """Backward with gradient scaler
-
-        :param loss: loss computed by a loss function
-        :type loss: torch.Tensor
-        """
-        loss = self.optim.scale_loss(loss)
-        loss.backward()
+        self.optim.backward(loss)

     def step(self):
         return self.optim.step()
...
colossalai/initialize.py  (+1 −1)

...
@@ -304,7 +304,7 @@ def initialize(model: nn.Module,
     if is_using_pp():
         assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
     if amp_mode == AMP_TYPE.NAIVE:
-        cfg_['clip_grad'] = clip_grad_norm
+        cfg_['clip_grad_norm'] = clip_grad_norm
     model, optimizer, criterion = convert_to_amp(model=model,
                                                  optimizer=optimizer,
                                                  criterion=criterion,
...
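Note (not part of the commit): with the key renamed to 'clip_grad_norm', the value forwarded by initialize() now matches the keyword the refactored FP16Optimizer actually accepts. A hedged config sketch, assuming clip_grad_norm is read from the global config by initialize() as in releases of this period; everything beyond the fp16/clip_grad_norm fields is illustrative:

    # config.py -- hypothetical Colossal-AI config using naive AMP.
    from colossalai.amp import AMP_TYPE

    clip_grad_norm = 1.0                 # forwarded into the naive-AMP config as 'clip_grad_norm'

    fp16 = dict(
        mode=AMP_TYPE.NAIVE,
        initial_scale=2 ** 16,           # routed to the grad scaler by convert_to_naive_amp
    )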
colossalai/zero/sharded_optim/sharded_optim.py  (+28 −41)

-from itertools import groupby
 from colossalai.utils.cuda import get_current_device
 import torch
 import torch.distributed as dist
...
@@ -7,7 +6,7 @@ from torch.optim import Optimizer
 from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor,
                      release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan)
...
@@ -16,11 +15,8 @@ from functools import partial
 class ShardedOptimizer(ColossalaiOptimizer):

     def __init__(self,
                  optimizer: Optimizer,
+
+                 # grad scaler config
                  initial_scale=2 ** 32,
                  min_scale=1,
                  growth_factor=2,
...
@@ -28,23 +24,14 @@ class ShardedOptimizer(ColossalaiOptimizer):
                  growth_interval=1000,
                  hysteresis=2,
                  max_scale: int = 2 ** 32,
+
+                 # grad clipping
                  clip_grad_norm=2.0,
                  verbose=False,
+
+                 # communication
                  reduce_bucket_size=500000000,
                  communication_dtype=torch.float16,
                  overlap_communication=False,
+
+                 # stage 2
                  partition_grad=False,
                  dp_parallel_mode=ParallelMode.DATA,
                  mp_parallel_mode=ParallelMode.MODEL,
+
+                 # cpu offload
                  cpu_offload=False,
                  cpu_fp16_param=False,
                  cpu_fp16_grad=False):
...
@@ -263,6 +250,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
             # args here is not grad, but allow_unreacable and accumulate_grad
             def reduce_grad_hook(*args):
                 reduction_func()

             accum_grad_obj.register_hook(reduce_grad_hook)

         _define_and_attach(param, reduce_rank)
...
@@ -444,7 +432,6 @@ class ShardedOptimizer(ColossalaiOptimizer):
             self._grad_store._averaged_gradients[group_id] = []
             self._grad_store._averaged_gradients[group_id] = []

         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
...
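Note (not part of the commit): the constructor arguments are now grouped by the comments above (grad scaler config, grad clipping, communication, stage 2, cpu offload). A hedged instantiation sketch using only keyword names taken from the diff; values are illustrative and a real run requires an initialized ColossalAI/ZeRO context:

    import torch
    from colossalai.zero.sharded_optim.sharded_optim import ShardedOptimizer

    model = torch.nn.Linear(8, 8).cuda().half()
    base_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    optim = ShardedOptimizer(base_optim,
                             # grad scaler config
                             initial_scale=2 ** 16,
                             growth_interval=1000,
                             # grad clipping
                             clip_grad_norm=1.0,
                             # communication
                             overlap_communication=True,
                             # stage 2
                             partition_grad=False)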
colossalai/zero/sharded_optim/sharded_optim_v2.py  (+1 −1)

...
@@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColossalaiOptimizer
...
tests/test_amp/test_naive_fp16.py  (new file, +83 −0)

import torch
import colossalai
import copy
import pytest
import torch.multiprocessing as mp
from colossalai.amp import convert_to_naive_amp
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.utils import free_port
from functools import partial


def check_equal(a, b):
    """
    This function checks if two tensors are equal within tolerance
    """
    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'


def run_naive_amp():
    """
    In this test, we compare the naive fp16 optimizer implemented in colossalai
    and fp32 torch optimizer
    """
    # create layer
    test_models = ['repeated_computed_layers', 'nested_model']
    for test_name in test_models:
        get_component_func = non_distributed_component_funcs.get_callable(test_name)
        model_builder, train_dataloader, _, optim_builder, _ = get_component_func()

        # create model
        amp_model = model_builder(checkpoint=True).cuda()
        torch_model = copy.deepcopy(amp_model)

        # create optimizer
        amp_optimizer = optim_builder(amp_model)
        torch_optimizer = optim_builder(torch_model)

        # inject naive amp
        amp_config = dict(initial_scale=1)
        amp_model, amp_optimizer = convert_to_naive_amp(amp_model, amp_optimizer, amp_config)

        # create data
        data_iter = iter(train_dataloader)
        data, label = next(data_iter)
        data = data.cuda()

        # forward pass
        amp_output = amp_model(data)
        torch_output = torch_model(data)
        assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}'

        # backward
        amp_optimizer.backward(amp_output.mean())
        torch_output.mean().backward()

        # check grad
        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
            torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3)

        # step
        amp_optimizer.step()
        torch_optimizer.step()

        # check updated param
        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
            torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    run_naive_amp()


@pytest.mark.dist
def test_naive_amp():
    world_size = 1
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_naive_amp()
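Note (not part of the commit): the test spawns a single worker via mp.spawn, so it can presumably be run on one GPU either through pytest (pytest tests/test_amp/test_naive_fp16.py) or directly (python tests/test_amp/test_naive_fp16.py), since the __main__ guard calls test_naive_amp().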