Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e79ea442
Unverified
Commit
e79ea442
authored
Mar 15, 2022
by
Frank Lee
Committed by
GitHub
Mar 15, 2022
Browse files
[fp16] refactored fp16 optimizer (#392)
parent
f8a0e7fb
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
380 additions
and
354 deletions
+380
-354
colossalai/amp/naive_amp/__init__.py
colossalai/amp/naive_amp/__init__.py
+16
-5
colossalai/amp/naive_amp/_fp16_optimizer.py
colossalai/amp/naive_amp/_fp16_optimizer.py
+198
-293
colossalai/amp/naive_amp/_utils.py
colossalai/amp/naive_amp/_utils.py
+40
-0
colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
+0
-2
colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
+11
-3
colossalai/amp/naive_amp/naive_amp.py
colossalai/amp/naive_amp/naive_amp.py
+2
-8
colossalai/initialize.py
colossalai/initialize.py
+1
-1
colossalai/zero/sharded_optim/sharded_optim.py
colossalai/zero/sharded_optim/sharded_optim.py
+28
-41
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+1
-1
tests/test_amp/test_naive_fp16.py
tests/test_amp/test_naive_fp16.py
+83
-0
No files found.
colossalai/amp/naive_amp/__init__.py
View file @
e79ea442
import
inspect
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
colossalai.utils
import
is_no_pp_or_last_stage
from
colossalai.utils
import
is_no_pp_or_last_stage
from
.naive_amp
import
NaiveAMPOptimizer
,
NaiveAMPModel
from
.naive_amp
import
NaiveAMPOptimizer
,
NaiveAMPModel
from
.grad_scaler
import
DynamicGradScaler
,
ConstantGradScaler
def
convert_to_naive_amp
(
model
:
nn
.
Module
,
def
convert_to_naive_amp
(
model
:
nn
.
Module
,
optimizer
:
Optimizer
,
amp_config
):
optimizer
:
Optimizer
,
amp_config
):
"""A helper function to wrap training components with naive AMP modules
"""A helper function to wrap training components with naive AMP modules
:param model: your model object
:param model: your model object
...
@@ -31,7 +30,19 @@ def convert_to_naive_amp(model: nn.Module,
...
@@ -31,7 +30,19 @@ def convert_to_naive_amp(model: nn.Module,
output_to_fp32
=
is_no_pp_or_last_stage
()
output_to_fp32
=
is_no_pp_or_last_stage
()
model
=
NaiveAMPModel
(
model
,
output_to_fp32
=
output_to_fp32
)
model
=
NaiveAMPModel
(
model
,
output_to_fp32
=
output_to_fp32
)
optimizer
=
NaiveAMPOptimizer
(
optimizer
,
**
amp_config
)
use_dynamic_grad_scaler
=
amp_config
.
pop
(
'dynamic_grad_scale'
,
True
)
if
use_dynamic_grad_scaler
:
scaler_class
=
DynamicGradScaler
else
:
scaler_class
=
ConstantGradScaler
sig
=
inspect
.
signature
(
scaler_class
.
__init__
)
kwargs
=
dict
()
for
param
in
sig
.
parameters
.
values
():
if
param
.
name
in
amp_config
:
kwargs
[
param
.
name
]
=
amp_config
.
pop
(
param
.
name
)
grad_scaler
=
scaler_class
(
**
kwargs
)
optimizer
=
NaiveAMPOptimizer
(
optimizer
,
grad_scaler
,
**
amp_config
)
return
model
,
optimizer
return
model
,
optimizer
...
...
colossalai/amp/naive_amp/_fp16_optimizer.py
View file @
e79ea442
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-
# -*- encoding: utf-8 -*-
import
torch
import
torch
import
torch.distributed
as
dist
try
:
try
:
import
colossal_C
import
colossal_C
...
@@ -9,41 +10,30 @@ except:
...
@@ -9,41 +10,30 @@ except:
print
(
'Colossalai should be built with cuda extension to use the FP16 optimizer'
)
print
(
'Colossalai should be built with cuda extension to use the FP16 optimizer'
)
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.context
import
ParallelMode
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils
import
(
print_rank_0
,
copy_tensor_parallel_attributes
,
from
colossalai.utils
import
(
copy_tensor_parallel_attributes
,
clip_grad_norm_fp32
,
multi_tensor_applier
)
clip_grad_norm_fp32
,
count_zeros_fp32
,
multi_tensor_applier
)
from
torch.distributed
import
ProcessGroup
from
.grad_scaler
import
BaseGradScaler
from
._utils
import
has_inf_or_nan
,
zero_gard_by_list
def
_zero_grad_group_helper
(
group
,
set_to_none
):
__all__
=
[
'FP16Optimizer'
]
"""Zero out the gradient for a group of parameters.
Note: copied from torch.optim.optimizer."""
for
param
in
group
:
if
param
.
grad
is
not
None
:
if
set_to_none
:
param
.
grad
=
None
else
:
if
param
.
grad
.
grad_fn
is
not
None
:
param
.
grad
.
detach_
()
else
:
param
.
grad
.
requires_grad_
(
False
)
param
.
grad
.
zero_
()
def
_multi_tensor_copy_this_to_that
(
this
,
that
,
overflow_buf
=
None
):
def
_multi_tensor_copy_this_to_that
(
this
,
that
,
overflow_buf
=
None
):
"""Use multi-tensor-applier to copy values from one list to another.
"""
adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)
Use multi-tensor-applier to copy values from one list to another.
We don't have a blfoat16 implementation so for now if the overflow_buf
We don't have a blfoat16 implementation so for now if the overflow_buf
is not provided, we default back to simple loop copy to be compatible
is not provided, we default back to simple loop copy to be compatible
with bfloat16."""
with bfloat16.
"""
if
overflow_buf
:
if
overflow_buf
:
overflow_buf
.
fill_
(
0
)
overflow_buf
.
fill_
(
0
)
# Scaling with factor `1.0` is equivalent to copy.
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier
(
colossal_C
.
multi_tensor_scale
,
multi_tensor_applier
(
colossal_C
.
multi_tensor_scale
,
overflow_buf
,
[
this
,
that
],
1.0
)
overflow_buf
,
[
this
,
that
],
1.0
)
else
:
else
:
for
this_
,
that_
in
zip
(
this
,
that
):
for
this_
,
that_
in
zip
(
this
,
that
):
that_
.
copy_
(
this_
)
that_
.
copy_
(
this_
)
...
@@ -111,8 +101,7 @@ class DynamicGradScaler:
...
@@ -111,8 +101,7 @@ class DynamicGradScaler:
self
.
_hysteresis_tracker
-=
1
self
.
_hysteresis_tracker
-=
1
# Now if we are out of hysteresis count, scale down the loss.
# Now if we are out of hysteresis count, scale down the loss.
if
self
.
_hysteresis_tracker
<=
0
:
if
self
.
_hysteresis_tracker
<=
0
:
self
.
_scale
=
torch
.
max
(
self
.
_scale
*
self
.
backoff_factor
,
self
.
_scale
=
torch
.
max
(
self
.
_scale
*
self
.
backoff_factor
,
self
.
min_scale
)
self
.
min_scale
)
if
self
.
verbose
:
if
self
.
verbose
:
self
.
_logger
.
info
(
f
'overflow occurs, loss scale is adjusted to
{
self
.
_scale
}
'
,
ranks
=
[
0
])
self
.
_logger
.
info
(
f
'overflow occurs, loss scale is adjusted to
{
self
.
_scale
}
'
,
ranks
=
[
0
])
else
:
else
:
...
@@ -127,12 +116,13 @@ class DynamicGradScaler:
...
@@ -127,12 +116,13 @@ class DynamicGradScaler:
if
self
.
_max_scale
is
not
None
and
self
.
_scale
>=
self
.
_max_scale
:
if
self
.
_max_scale
is
not
None
and
self
.
_scale
>=
self
.
_max_scale
:
if
self
.
verbose
:
if
self
.
verbose
:
self
.
_logger
.
info
(
self
.
_logger
.
info
(
f
'Current loss scale
{
self
.
_scale
}
has reached the max scale
{
self
.
_max_scale
}
allowed'
,
ranks
=
[
0
])
f
'Current loss scale
{
self
.
_scale
}
has reached the max scale
{
self
.
_max_scale
}
allowed'
,
ranks
=
[
0
])
else
:
else
:
self
.
_scale
=
self
.
_scale
*
self
.
growth_factor
self
.
_scale
=
self
.
_scale
*
self
.
growth_factor
if
self
.
verbose
:
if
self
.
verbose
:
self
.
_logger
.
info
(
self
.
_logger
.
info
(
f
'no consecutive overflow, loss scale is adjusted to
{
self
.
_scale
}
'
,
f
'no consecutive overflow, loss scale is adjusted to
{
self
.
_scale
}
'
,
ranks
=
[
0
])
ranks
=
[
0
])
def
state_dict
(
self
):
def
state_dict
(
self
):
state_dict
=
{}
state_dict
=
{}
...
@@ -173,326 +163,241 @@ class FP16Optimizer(Optimizer):
...
@@ -173,326 +163,241 @@ class FP16Optimizer(Optimizer):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
optimizer
,
optimizer
:
Optimizer
,
clip_grad
=
0
,
grad_scaler
:
BaseGradScaler
,
log_num_zeros_in_grad
=
False
,
verbose
:
bool
=
False
,
initial_scale
=
2
**
32
,
clip_grad_norm
=
0
,
min_scale
=
1
,
dp_process_group
:
ProcessGroup
=
None
,
growth_factor
=
2
,
mp_process_group
:
ProcessGroup
=
None
):
backoff_factor
=
0.5
,
growth_interval
=
1000
,
hysteresis
=
2
,
max_scale
:
int
=
2
**
32
,
verbose
:
bool
=
False
):
# default args for compatibility
bf16
=
False
params_have_main_grad
=
False
# have a defaults for compatibility with pytorch optim
# have a defaults for compatibility with pytorch optim
self
.
defaults
=
optimizer
.
defaults
self
.
_optimizer
=
optimizer
self
.
_defaults
=
optimizer
.
defaults
# log config
self
.
_logger
=
get_dist_logger
()
# fp16-related params
if
verbose
:
assert
isinstance
(
grad_scaler
,
BaseGradScaler
)
self
.
_logger
.
info
(
f
"
\n
========= FP16 Optimizer Config =========
\n
"
self
.
_grad_scaler
=
grad_scaler
f
"Optimizer:
{
optimizer
.
__class__
.
__name__
}
\n
"
self
.
_found_overflow
=
torch
.
cuda
.
FloatTensor
([
0.0
])
f
"clip_grad =
{
clip_grad
}
\n
"
self
.
_dummy_overflow_buf
=
torch
.
cuda
.
IntTensor
([
0
])
f
"log_num_zeros_in_grad =
{
log_num_zeros_in_grad
}
\n
"
f
"initial_scale =
{
initial_scale
}
\n
"
# misc params
f
"min_scale =
{
min_scale
}
\n
"
self
.
_clip_grad_max_norm
=
clip_grad_norm
f
"growth_factor =
{
growth_factor
}
\n
"
f
"backoff_factor =
{
backoff_factor
}
\n
"
# get process group
f
"growth_interval =
{
growth_interval
}
\n
"
def
_get_process_group
(
parallel_mode
):
f
"hysteresis =
{
hysteresis
}
\n
"
if
gpc
.
is_initialized
(
ParallelMode
.
DATA
)
and
gpc
.
get_world_size
(
ParallelMode
.
DATA
):
f
"=========================================="
,
ranks
=
[
0
])
return
gpc
.
get_group
(
ParallelMode
.
DATA
)
else
:
"""Input optimizer is the base optimizer for example Adam."""
return
None
self
.
optimizer
=
optimizer
assert
self
.
optimizer
,
'no optimizer is provided.'
if
dp_process_group
is
None
:
# Set gradient clipping and logging params.
dp_process_group
=
_get_process_group
(
ParallelMode
.
DATA
)
self
.
clip_grad
=
clip_grad
if
mp_process_group
is
None
:
self
.
log_num_zeros_in_grad
=
log_num_zeros_in_grad
mp_process_group
=
_get_process_group
(
ParallelMode
.
MODEL
)
self
.
params_have_main_grad
=
params_have_main_grad
self
.
_dp_process_group
=
dp_process_group
self
.
bf16
=
bf16
self
.
_mp_process_group
=
mp_process_group
self
.
grad_scaler
=
DynamicGradScaler
(
initial_scale
=
initial_scale
,
# we maintain three groups of parameters
min_scale
=
min_scale
,
# so that the model can have a mixture
growth_factor
=
growth_factor
,
# of fp16 and fp32 params
backoff_factor
=
backoff_factor
,
# fp16_param_groups: the fp16 params of the model
growth_interval
=
growth_interval
,
# fp32_master_param_groups: the fp32 params cast from the fp16 param of the model
hysteresis
=
hysteresis
,
# fp32_param_groups: the fp32 params of the model
max_scale
=
max_scale
,
# NOTE:
verbose
=
verbose
# 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence
)
# 2. fp32_param_groups and fp16_param_groups are exclusive of each other
self
.
_fp16_param_groups
=
[]
# None grad scaler is only supported for bf16.
self
.
_fp32_master_param_groups
=
[]
if
self
.
grad_scaler
is
None
:
self
.
_fp32_param_groups
=
[]
assert
self
.
bf16
,
'fp16 expects a grad scaler.'
# Tensor used to determine if a nan/if has happend.
# Any non-zero value indicates inf/nan.
# Note that we keep this for the cases that grad scaler is none.
# We still record nan/inf if we have a bfloat16 with a grad scaler.
if
self
.
grad_scaler
:
self
.
found_inf
=
torch
.
cuda
.
FloatTensor
([
0.0
])
# Dummy tensor needed for apex multi-apply tensor.
# For bfloat, we don't have multi-tensor apply and for now
# we set it to none so the multi-tensor apply gets ignored.
if
bf16
:
self
.
_dummy_overflow_buf
=
None
else
:
self
.
_dummy_overflow_buf
=
torch
.
cuda
.
IntTensor
([
0
])
# In case grad scaler is not passed, define the unity scale.
if
self
.
grad_scaler
is
None
:
self
.
_scale_one
=
torch
.
cuda
.
FloatTensor
([
1.0
])
# ======================
# main parameter stuff
# ======================
# Three groups of parameters:
# float16_groups: original float16 parameters
# fp32_from_float16_groups: fp32 copy of float16 parameters
# fp32_from_fp32_groups: original fp32 parameters
self
.
float16_groups
=
[]
self
.
fp32_from_float16_groups
=
[]
self
.
fp32_from_fp32_groups
=
[]
# For all the groups in the original optimizer:
# For all the groups in the original optimizer:
for
param_group
in
self
.
optimizer
.
param_groups
:
for
param_group
in
self
.
_
optimizer
.
param_groups
:
f
loat
16_params
_this_group
=
[]
f
p
16_params
=
[]
fp32_
params_this_group
=
[]
fp32_
master_params
=
[]
fp32_
from_float16_params_this_group
=
[]
fp32_
params
=
[]
# For all the parameters in this group:
# For all the parameters in this group:
for
i
,
param
in
enumerate
(
param_group
[
'params'
]):
for
i
,
param
in
enumerate
(
param_group
[
'params'
]):
if
param
.
requires_grad
:
if
param
.
requires_grad
:
# float16 params:
# float16 params:
if
param
.
type
()
in
[
'torch.cuda.HalfTensor'
,
if
param
.
type
()
in
[
'torch.cuda.HalfTensor'
]:
'torch.cuda.BFloat16Tensor'
]:
fp16_params
.
append
(
param
)
float16_params_this_group
.
append
(
param
)
# Create a copy
main_param
=
param
.
detach
().
clone
().
float
()
# Copy tensor model parallel attributes.
copy_tensor_parallel_attributes
(
param
,
main_param
)
# if hasattr(param, 'shared'):
# Create a fp32 copy
# main_param.shared = param.shared
fp32_param
=
param
.
detach
().
clone
().
float
()
# Copy tensor model parallel attributes.
copy_tensor_parallel_attributes
(
param
,
fp32_param
)
# Replace the optimizer params with the new fp32 copy.
# Replace the optimizer params with the new fp32 copy.
param_group
[
'params'
][
i
]
=
main_param
param_group
[
'params'
][
i
]
=
fp32_param
fp32_from_float16_params_this_group
.
append
(
main_param
)
fp32_master_params
.
append
(
fp32_param
)
# Reset existing state dict key to the new main param.
# Reset existing state dict key to the new main param.
if
param
in
self
.
optimizer
.
state
:
if
param
in
self
.
_optimizer
.
state
:
self
.
optimizer
.
state
[
main_param
]
\
self
.
_optimizer
.
state
[
fp32_param
]
=
self
.
_optimizer
.
state
.
pop
(
param
)
=
self
.
optimizer
.
state
.
pop
(
param
)
# fp32 params.
# fp32 params.
elif
param
.
type
()
==
'torch.cuda.FloatTensor'
:
elif
param
.
type
()
==
'torch.cuda.FloatTensor'
:
fp32_params_this_group
.
append
(
param
)
fp32_params
.
append
(
param
)
param_group
[
'params'
][
i
]
=
param
else
:
else
:
raise
TypeError
(
'Wrapped parameters must be one of '
raise
TypeError
(
'Expected parameter of type torch.cuda.FloatTensor '
'torch.cuda.FloatTensor, '
f
'or torch.cuda.HalfTensor, but got
{
param
.
type
()
}
'
)
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'
.
format
(
param
.
type
()))
self
.
float16_groups
.
append
(
float16_params_this_group
)
self
.
_fp16_param_groups
.
append
(
fp16_params
)
self
.
fp32_from_float16_groups
.
append
(
self
.
_fp32_master_param_groups
.
append
(
fp32_master_params
)
fp32_from_float16_params_this_group
)
self
.
_fp32_param_groups
.
append
(
fp32_params
)
self
.
fp32_from_fp32_groups
.
append
(
fp32_params_this_group
)
# Leverage state_dict() and load_state_dict() to
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
# recast preexisting per-param state tensors
self
.
optimizer
.
load_state_dict
(
self
.
optimizer
.
state_dict
())
self
.
_optimizer
.
load_state_dict
(
self
.
_optimizer
.
state_dict
())
def
zero_grad
(
self
,
set_to_none
=
False
):
"""We only need to zero the model related parameters, i.e.,
float16_groups & fp32_from_fp32_groups."""
for
group
in
self
.
float16_groups
:
_zero_grad_group_helper
(
group
,
set_to_none
)
for
group
in
self
.
fp32_from_fp32_groups
:
_zero_grad_group_helper
(
group
,
set_to_none
)
def
get_loss_scale
(
self
):
if
self
.
grad_scaler
is
None
:
return
self
.
_scale_one
return
self
.
grad_scaler
.
scale
def
_copy_model_grads_to_main_grads
(
self
):
# This only needs to be done for the float16 group.
for
model_group
,
main_group
in
zip
(
self
.
float16_groups
,
self
.
fp32_from_float16_groups
):
for
model_param
,
main_param
in
zip
(
model_group
,
main_group
):
if
self
.
params_have_main_grad
:
main_param
.
grad
=
model_param
.
main_grad
.
float
()
else
:
if
model_param
.
grad
is
not
None
:
main_param
.
grad
=
model_param
.
grad
.
float
()
# For fp32 grads, we need to reset the grads to main grad.
if
self
.
params_have_main_grad
:
for
model_group
in
self
.
fp32_from_fp32_groups
:
for
model_param
in
model_group
:
model_param
.
grad
=
model_param
.
main_grad
def
_unscale_main_grads_and_check_for_nan
(
self
):
main_grads
=
[]
# fp32 params fromm float16 ones.
for
main_group
in
self
.
fp32_from_float16_groups
:
for
main_param
in
main_group
:
if
main_param
.
grad
is
not
None
:
main_grads
.
append
(
main_param
.
grad
.
data
)
# Append fp32 parameters.
for
main_group
in
self
.
fp32_from_fp32_groups
:
for
main_param
in
main_group
:
if
main_param
.
grad
is
not
None
:
main_grads
.
append
(
main_param
.
grad
.
data
)
# Reset found inf.
self
.
found_inf
.
fill_
(
0.0
)
# Unscale and set found inf/nan
torch
.
_amp_foreach_non_finite_check_and_unscale_
(
main_grads
,
self
.
found_inf
,
self
.
grad_scaler
.
inv_scale
)
# Update across all model parallel instances.
torch
.
distributed
.
all_reduce
(
self
.
found_inf
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
gpc
.
get_group
(
ParallelMode
.
MODEL
))
# Check for nan.
found_inf_flag
=
(
self
.
found_inf
.
item
()
>
0
)
return
found_inf_flag
def
_get_model_and_main_params_data_float16
(
self
):
model_data
=
[]
main_data
=
[]
for
model_group
,
main_group
in
zip
(
self
.
float16_groups
,
self
.
fp32_from_float16_groups
):
for
model_param
,
main_param
in
zip
(
model_group
,
main_group
):
model_data
.
append
(
model_param
.
data
)
main_data
.
append
(
main_param
.
data
)
return
model_data
,
main_data
def
_copy_main_params_to_model_params
(
self
):
# Only needed for the float16 params.
model_data
,
main_data
=
self
.
_get_model_and_main_params_data_float16
()
_multi_tensor_copy_this_to_that
(
this
=
main_data
,
that
=
model_data
,
overflow_buf
=
self
.
_dummy_overflow_buf
)
def
_copy_model_params_to_main_params
(
self
):
# log config
# Only needed for the float16 params.
self
.
_logger
=
get_dist_logger
()
model_data
,
main_data
=
self
.
_get_model_and_main_params_data_float16
()
if
verbose
:
_multi_tensor_copy_this_to_that
(
this
=
model_data
,
that
=
main_data
,
self
.
_logger
.
info
(
overflow_buf
=
self
.
_dummy_overflow_buf
)
f
"
\n
========= FP16 Optimizer Config =========
\n
"
f
"Optimizer:
{
optimizer
.
__class__
.
__name__
}
\n
"
f
"clip_grad_norm =
{
clip_grad_norm
}
\n
"
f
"grad_scaler =
{
self
.
_grad_scaler
.
__class__
.
__name__
}
"
f
"=========================================="
,
ranks
=
[
0
])
def
reload_model_params
(
self
):
@
property
self
.
_copy_model_params_to_main_params
()
def
grad_scaler
(
self
):
return
self
.
_grad_scaler
@
torch
.
no_grad
()
@
property
def
step
(
self
):
def
loss_scale
(
self
):
# Copy gradients from model params to main params.
return
self
.
_grad_scaler
.
scale
self
.
_copy_model_grads_to_main_grads
()
# Do unscale, check for inf, and update grad scaler only for
@
property
# the case that grad scaler is provided.
def
optimizer
(
self
):
if
self
.
grad_scaler
:
return
self
.
_optimizer
@
property
def
defaults
(
self
):
return
self
.
_defaults
# Unscale and check for inf/nan.
def
_check_overflow
(
self
):
found_inf_flag
=
self
.
_unscale_main_grads_and_check_for_nan
()
# clear previous overflow record
self
.
_found_overflow
.
fill_
(
0.0
)
# We are done with scaling gradients
# check for overflow
# so we can update the loss scale.
for
group
in
self
.
_optimizer
.
param_groups
:
self
.
grad_scaler
.
update
(
found_inf_flag
)
for
p
in
group
[
'params'
]:
if
has_inf_or_nan
(
p
.
grad
):
self
.
_found_overflow
.
fill_
(
1.0
)
break
# If we found inf/nan, skip the update.
# all-reduce across dp group
if
found_inf_flag
:
if
self
.
_dp_process_group
:
return
False
,
None
,
None
dist
.
all_reduce
(
self
.
_found_overflow
,
op
=
dist
.
ReduceOp
.
MAX
,
group
=
self
.
_dp_process_group
)
# all-reduce over model parallel group
if
self
.
_mp_process_group
:
dist
.
all_reduce
(
self
.
_found_overflow
,
op
=
dist
.
ReduceOp
.
MAX
,
group
=
self
.
_mp_process_group
)
return
self
.
_found_overflow
.
item
()
>
0
def
zero_grad
(
self
,
set_to_none
=
True
):
# set_to_none = True can save some memory space
for
param_group
in
self
.
_optimizer
.
param_groups
:
zero_gard_by_list
(
param_group
[
'params'
],
set_to_none
=
set_to_none
)
def
_get_fp32_param_groups_to_update
(
self
):
return
self
.
_fp32_master_param_groups
+
self
.
_fp32_param_groups
def
_unscale_grads
(
self
):
for
group
in
self
.
_get_fp32_param_groups_to_update
():
for
p
in
group
:
if
p
.
grad
is
not
None
:
p
.
grad
.
data
.
div_
(
self
.
loss_scale
)
def
_assign_grad_to_fp32_master_param
(
self
):
# This only needs to be done for the float16 group.
for
fp16_param_group
,
fp32_master_param_group
in
zip
(
self
.
_fp16_param_groups
,
self
.
_fp32_master_param_groups
):
for
fp16_param
,
fp32_param
in
zip
(
fp16_param_group
,
fp32_master_param_group
):
fp32_param
.
grad
=
fp16_param
.
grad
.
float
()
# clear unneeded grad on fp16 param
fp16_param
.
grad
=
None
def
_update_fp16_param_from_fp32_param
(
self
):
fp16_param_data
=
[]
fp32_master_param_data
=
[]
for
fp16_group
,
fp32_group
in
zip
(
self
.
_fp16_param_groups
,
self
.
_fp32_master_param_groups
):
for
fp16_param
,
fp32_param
in
zip
(
fp16_group
,
fp32_group
):
fp16_param_data
.
append
(
fp16_param
.
data
)
fp32_master_param_data
.
append
(
fp32_param
.
data
)
_multi_tensor_copy_this_to_that
(
this
=
fp32_master_param_data
,
that
=
fp16_param_data
,
overflow_buf
=
self
.
_dummy_overflow_buf
)
def
step
(
self
):
# Copy gradients from model params to main params.
self
.
_assign_grad_to_fp32_master_param
()
self
.
_unscale_grads
()
overflow
=
self
.
_check_overflow
()
self
.
_grad_scaler
.
update
(
overflow
)
if
overflow
:
self
.
zero_grad
()
return
False
,
None
# Clip the main gradients.
# Clip the main gradients.
grad_norm
=
None
grad_norm
=
None
if
self
.
clip_grad
>
0.0
:
if
self
.
_clip_grad_max_norm
>
0.0
:
grad_norm
=
self
.
clip_grad_norm
(
self
.
clip_grad
)
grad_norm
=
self
.
clip_grad_norm
(
self
.
_clip_grad_max_norm
)
# count the zeros in the grads
num_zeros_in_grad
=
self
.
count_zeros
()
if
\
self
.
log_num_zeros_in_grad
else
None
# Step the optimizer.
# Step the optimizer.
self
.
optimizer
.
step
()
self
.
_
optimizer
.
step
()
# Update params from main params.
# Update params from main params.
self
.
_
copy_main_params_to_model
_param
s
()
self
.
_
update_fp16_param_from_fp32
_param
()
# Successful update.
# Successful update.
return
True
,
grad_norm
,
num_zeros_in_grad
return
True
,
grad_norm
def
backward
(
self
,
loss
):
scaled_loss
=
loss
*
self
.
grad_scaler
.
scale
scaled_loss
.
backward
()
def
state_dict
(
self
):
def
state_dict
(
self
):
state_dict
=
{}
state_dict
=
{}
state_dict
[
'optimizer'
]
=
self
.
optimizer
.
state_dict
()
state_dict
[
'optimizer'
]
=
self
.
_
optimizer
.
state_dict
()
if
self
.
grad_scaler
:
if
self
.
grad_scaler
:
state_dict
[
'grad_scaler'
]
=
self
.
grad_scaler
.
state_dict
()
state_dict
[
'grad_scaler'
]
=
self
.
grad_scaler
.
state_dict
()
state_dict
[
'fp32_
from_fp16_param
s'
]
=
self
.
fp32_
from_float16
_groups
state_dict
[
'fp32_
master_param_group
s'
]
=
self
.
_
fp32_
master_param
_groups
return
state_dict
return
state_dict
def
load_state_dict
(
self
,
state_dict
):
def
load_state_dict
(
self
,
state_dict
):
# Optimizer.
# Optimizer.
optimizer_key
=
'optimizer'
self
.
_optimizer
.
load_state_dict
(
state_dict
[
'optimizer'
])
if
optimizer_key
not
in
state_dict
:
optimizer_key
=
'optimizer_state_dict'
print_rank_0
(
'***WARNING*** loading optimizer from '
'an old checkpoint ...'
)
self
.
optimizer
.
load_state_dict
(
state_dict
[
optimizer_key
])
# Grad scaler.
# Grad scaler.
if
'grad_scaler'
not
in
state_dict
:
if
'grad_scaler'
in
state_dict
:
print_rank_0
(
'***WARNING*** found an old checkpoint, will not '
self
.
grad_scaler
.
load_state_dict
(
state_dict
[
'grad_scaler'
])
'load grad scaler ...'
)
else
:
if
self
.
grad_scaler
:
self
.
grad_scaler
.
load_state_dict
(
state_dict
[
'grad_scaler'
])
else
:
print_rank_0
(
'***WARNING*** fould the grad scaler in the '
'checkpoint but it is None in the class. '
'Skipping loading grad scaler ...'
)
# Copy data for the main params.
# Copy data for the main params.
fp32_from_float16_params_key
=
'fp32_from_fp16_params'
if
'fp32_master_param_groups'
in
state_dict
:
if
fp32_from_float16_params_key
not
in
state_dict
:
for
current_group
,
ckpt_group
in
zip
(
self
.
_fp32_master_param_groups
,
fp32_from_float16_params_key
=
'fp32_from_fp16'
state_dict
[
'fp32_master_param_groups'
]):
for
current_group
,
saved_group
in
zip
(
for
current_param
,
ckpt_param
in
zip
(
current_group
,
ckpt_group
):
self
.
fp32_from_float16_groups
,
current_param
.
data
.
copy_
(
ckpt_param
.
data
)
state_dict
[
fp32_from_float16_params_key
]):
for
current_param
,
saved_param
in
zip
(
current_group
,
saved_group
):
def
clip_grad_norm
(
self
,
clip_grad
):
current_param
.
data
.
copy_
(
saved_param
.
data
)
def
get_parameters
(
self
):
params
=
[]
params
=
[]
for
param_group
in
self
.
optimizer
.
param_groups
:
for
param_group
in
self
.
_
optimizer
.
param_groups
:
for
param
in
param_group
[
'params'
]:
for
param
in
param_group
[
'params'
]:
params
.
append
(
param
)
params
.
append
(
param
)
return
params
def
clip_grad_norm
(
self
,
clip_grad
):
params
=
self
.
get_parameters
()
return
clip_grad_norm_fp32
(
params
,
clip_grad
)
return
clip_grad_norm_fp32
(
params
,
clip_grad
)
def
count_zeros
(
self
):
params
=
self
.
get_parameters
()
return
count_zeros_fp32
(
params
)
def
scale_loss
(
self
,
loss
):
"""Simple scaling."""
return
self
.
get_loss_scale
()
*
loss
# Promote state so it can be retrieved or set via
# Promote state so it can be retrieved or set via
# "optimizer_instance.state"
# "optimizer_instance.state"
def
_get_state
(
self
):
def
_get_state
(
self
):
return
self
.
optimizer
.
state
return
self
.
_
optimizer
.
state
def
_set_state
(
self
,
value
):
def
_set_state
(
self
,
value
):
self
.
optimizer
.
state
=
value
self
.
_
optimizer
.
state
=
value
state
=
property
(
_get_state
,
_set_state
)
state
=
property
(
_get_state
,
_set_state
)
...
@@ -500,9 +405,9 @@ class FP16Optimizer(Optimizer):
...
@@ -500,9 +405,9 @@ class FP16Optimizer(Optimizer):
# "optimizer_instance.param_groups"
# "optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
# (for example, to adjust the learning rate)
def
_get_param_groups
(
self
):
def
_get_param_groups
(
self
):
return
self
.
optimizer
.
param_groups
return
self
.
_
optimizer
.
param_groups
def
_set_param_groups
(
self
,
value
):
def
_set_param_groups
(
self
,
value
):
self
.
optimizer
.
param_groups
=
value
self
.
_
optimizer
.
param_groups
=
value
param_groups
=
property
(
_get_param_groups
,
_set_param_groups
)
param_groups
=
property
(
_get_param_groups
,
_set_param_groups
)
colossalai/amp/naive_amp/_utils.py
0 → 100644
View file @
e79ea442
from
typing
import
List
from
torch
import
Tensor
def
has_inf_or_nan
(
tensor
):
try
:
# if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as tensor
# (which is true for some recent version of pytorch).
tensor_sum
=
float
(
tensor
.
float
().
sum
())
# More efficient version that can be used if .sum() returns a Python scalar
# tensor_sum = float(tensor.sum())
except
RuntimeError
as
instance
:
# We want to check if inst is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if
"value cannot be converted"
not
in
instance
.
args
[
0
]:
raise
return
True
else
:
if
tensor_sum
==
float
(
'inf'
)
or
tensor_sum
==
-
float
(
'inf'
)
or
tensor_sum
!=
tensor_sum
:
return
True
return
False
def
zero_gard_by_list
(
tensor_list
:
List
[
Tensor
],
set_to_none
:
bool
=
True
)
->
None
:
"""
Clear the gradient of a list of tensors,
Note: copied from torch.optim.optimizer.
"""
for
param
in
tensor_list
:
if
param
.
grad
is
not
None
:
if
set_to_none
:
param
.
grad
=
None
else
:
if
param
.
grad
.
grad_fn
is
not
None
:
param
.
grad
.
detach_
()
else
:
param
.
grad
.
requires_grad_
(
False
)
param
.
grad
.
zero_
()
colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
View file @
e79ea442
...
@@ -28,12 +28,10 @@ class BaseGradScaler(ABC):
...
@@ -28,12 +28,10 @@ class BaseGradScaler(ABC):
def
inv_scale
(
self
)
->
Tensor
:
def
inv_scale
(
self
)
->
Tensor
:
return
self
.
_scale
.
double
().
reciprocal
().
float
()
return
self
.
_scale
.
double
().
reciprocal
().
float
()
@
abstractmethod
def
state_dict
(
self
)
->
Dict
:
def
state_dict
(
self
)
->
Dict
:
state_dict
=
dict
()
state_dict
=
dict
()
state_dict
[
'scale'
]
=
self
.
scale
state_dict
[
'scale'
]
=
self
.
scale
@
abstractmethod
def
load_state_dict
(
self
,
state_dict
:
Dict
)
->
None
:
def
load_state_dict
(
self
,
state_dict
:
Dict
)
->
None
:
self
.
_scale
=
state_dict
[
'scale'
]
self
.
_scale
=
state_dict
[
'scale'
]
...
...
colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
View file @
e79ea442
...
@@ -16,11 +16,19 @@ class DynamicGradScaler(BaseGradScaler):
...
@@ -16,11 +16,19 @@ class DynamicGradScaler(BaseGradScaler):
growth_interval
:
int
=
1000
,
growth_interval
:
int
=
1000
,
min_scale
:
int
=
None
,
min_scale
:
int
=
None
,
max_scale
:
int
=
None
,
max_scale
:
int
=
None
,
hysteresis
:
int
=
None
,
hysteresis
:
int
=
2
,
verbose
:
bool
=
False
):
verbose
:
bool
=
False
):
super
().
__init__
(
initial_scale
,
verbose
)
super
().
__init__
(
initial_scale
,
verbose
)
self
.
_min_scale
=
min_scale
if
min_scale
:
self
.
_max_scale
=
max_scale
self
.
_min_scale
=
torch
.
cuda
.
FloatTensor
([
min_scale
])
else
:
self
.
_min_scale
=
None
if
max_scale
:
self
.
_max_scale
=
torch
.
cuda
.
FloatTensor
([
max_scale
])
else
:
self
.
_max_scale
=
None
self
.
_growth_factor
=
growth_factor
self
.
_growth_factor
=
growth_factor
self
.
_backoff_factor
=
backoff_factor
self
.
_backoff_factor
=
backoff_factor
self
.
_growth_interval
=
growth_interval
self
.
_growth_interval
=
growth_interval
...
...
colossalai/amp/naive_amp/naive_amp.py
View file @
e79ea442
...
@@ -26,17 +26,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
...
@@ -26,17 +26,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
"""
"""
def
__init__
(
self
,
optim
:
Optimizer
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
optim
:
Optimizer
,
*
args
,
**
kwargs
):
optim
=
FP16Optimizer
(
optimizer
=
optim
,
*
args
,
**
kwargs
)
optim
=
FP16Optimizer
(
optim
,
*
args
,
**
kwargs
)
super
().
__init__
(
optim
)
super
().
__init__
(
optim
)
def
backward
(
self
,
loss
:
Tensor
):
def
backward
(
self
,
loss
:
Tensor
):
"""Backward with gradient scaler
self
.
optim
.
backward
(
loss
)
:param loss: loss computed by a loss function
:type loss: torch.Tensor
"""
loss
=
self
.
optim
.
scale_loss
(
loss
)
loss
.
backward
()
def
step
(
self
):
def
step
(
self
):
return
self
.
optim
.
step
()
return
self
.
optim
.
step
()
...
...
colossalai/initialize.py
View file @
e79ea442
...
@@ -304,7 +304,7 @@ def initialize(model: nn.Module,
...
@@ -304,7 +304,7 @@ def initialize(model: nn.Module,
if
is_using_pp
():
if
is_using_pp
():
assert
amp_mode
==
AMP_TYPE
.
NAIVE
,
'Pipeline only support NaiveAMP currently'
assert
amp_mode
==
AMP_TYPE
.
NAIVE
,
'Pipeline only support NaiveAMP currently'
if
amp_mode
==
AMP_TYPE
.
NAIVE
:
if
amp_mode
==
AMP_TYPE
.
NAIVE
:
cfg_
[
'clip_grad'
]
=
clip_grad_norm
cfg_
[
'clip_grad
_norm
'
]
=
clip_grad_norm
model
,
optimizer
,
criterion
=
convert_to_amp
(
model
=
model
,
model
,
optimizer
,
criterion
=
convert_to_amp
(
model
=
model
,
optimizer
=
optimizer
,
optimizer
=
optimizer
,
criterion
=
criterion
,
criterion
=
criterion
,
...
...
colossalai/zero/sharded_optim/sharded_optim.py
View file @
e79ea442
from
itertools
import
groupby
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -7,7 +6,7 @@ from torch.optim import Optimizer
...
@@ -7,7 +6,7 @@ from torch.optim import Optimizer
from
.bookkeeping
import
ParameterStore
,
GradientStore
,
BucketStore
,
TensorBucket
from
.bookkeeping
import
ParameterStore
,
GradientStore
,
BucketStore
,
TensorBucket
from
colossalai.context
import
ParallelMode
from
colossalai.context
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.amp.naive_amp.
_fp16_optimiz
er
import
DynamicGradScaler
from
colossalai.amp.naive_amp.
grad_scal
er
import
DynamicGradScaler
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
._utils
import
(
move_tensor
,
flatten
,
get_grad_accumulate_object
,
split_half_float_double
,
reduce_tensor
,
from
._utils
import
(
move_tensor
,
flatten
,
get_grad_accumulate_object
,
split_half_float_double
,
reduce_tensor
,
release_param_grad
,
calculate_global_norm_from_list
,
compute_norm
,
sync_param
,
has_inf_or_nan
)
release_param_grad
,
calculate_global_norm_from_list
,
compute_norm
,
sync_param
,
has_inf_or_nan
)
...
@@ -16,38 +15,26 @@ from functools import partial
...
@@ -16,38 +15,26 @@ from functools import partial
class
ShardedOptimizer
(
ColossalaiOptimizer
):
class
ShardedOptimizer
(
ColossalaiOptimizer
):
def
__init__
(
def
__init__
(
self
,
self
,
optimizer
:
Optimizer
,
optimizer
:
Optimizer
,
initial_scale
=
2
**
32
,
min_scale
=
1
,
# grad scaler config
growth_factor
=
2
,
initial_scale
=
2
**
32
,
backoff_factor
=
0.5
,
min_scale
=
1
,
growth_interval
=
1000
,
growth_factor
=
2
,
hysteresis
=
2
,
backoff_factor
=
0.5
,
max_scale
:
int
=
2
**
32
,
growth_interval
=
1000
,
clip_grad_norm
=
2.0
,
hysteresis
=
2
,
verbose
=
False
,
max_scale
:
int
=
2
**
32
,
reduce_bucket_size
=
500000000
,
communication_dtype
=
torch
.
float16
,
# grad clipping
overlap_communication
=
False
,
clip_grad_norm
=
2.0
,
partition_grad
=
False
,
verbose
=
False
,
dp_parallel_mode
=
ParallelMode
.
DATA
,
mp_parallel_mode
=
ParallelMode
.
MODEL
,
# communication
cpu_offload
=
False
,
reduce_bucket_size
=
500000000
,
cpu_fp16_param
=
False
,
communication_dtype
=
torch
.
float16
,
cpu_fp16_grad
=
False
):
overlap_communication
=
False
,
# stage 2
partition_grad
=
False
,
dp_parallel_mode
=
ParallelMode
.
DATA
,
mp_parallel_mode
=
ParallelMode
.
MODEL
,
# cpu offload
cpu_offload
=
False
,
cpu_fp16_param
=
False
,
cpu_fp16_grad
=
False
):
# TODO: add support for
# TODO: add support for
# 1. fp16 master weights
# 1. fp16 master weights
...
@@ -257,12 +244,13 @@ class ShardedOptimizer(ColossalaiOptimizer):
...
@@ -257,12 +244,13 @@ class ShardedOptimizer(ColossalaiOptimizer):
reduction_func
=
partial
(
self
.
_reduce_and_remove_grads_by_bucket
,
reduction_func
=
partial
(
self
.
_reduce_and_remove_grads_by_bucket
,
param
=
param
,
param
=
param
,
reduce_rank
=
reduce_rank
)
reduce_rank
=
reduce_rank
)
# define hook
# define hook
# NOT IMPORTANT BUT GOOD TO KNOW:
# NOT IMPORTANT BUT GOOD TO KNOW:
# args here is not grad, but allow_unreachable and accumulate_grad
# args here is not grad, but allow_unreachable and accumulate_grad
def
reduce_grad_hook
(
*
args
):
def
reduce_grad_hook
(
*
args
):
reduction_func
()
reduction_func
()
accum_grad_obj
.
register_hook
(
reduce_grad_hook
)
accum_grad_obj
.
register_hook
(
reduce_grad_hook
)
_define_and_attach
(
param
,
reduce_rank
)
_define_and_attach
(
param
,
reduce_rank
)
...
@@ -293,8 +281,8 @@ class ShardedOptimizer(ColossalaiOptimizer):
...
@@ -293,8 +281,8 @@ class ShardedOptimizer(ColossalaiOptimizer):
def
_reduce_grads_in_bucket
(
self
,
reduce_rank
=
None
):
def
_reduce_grads_in_bucket
(
self
,
reduce_rank
=
None
):
# reduce grads
# reduce grads
self
.
_reduce_grads_by_rank
(
reduce_rank
=
reduce_rank
,
self
.
_reduce_grads_by_rank
(
reduce_rank
=
reduce_rank
,
grads
=
self
.
_bucket_store
.
get_grad
(
reduce_rank
=
reduce_rank
),
grads
=
self
.
_bucket_store
.
get_grad
(
reduce_rank
=
reduce_rank
),
bucket_size
=
self
.
_bucket_store
.
num_elements_in_bucket
(
reduce_rank
))
bucket_size
=
self
.
_bucket_store
.
num_elements_in_bucket
(
reduce_rank
))
# use communication stream if overlapping
# use communication stream if overlapping
# communication with computation
# communication with computation
...
@@ -323,7 +311,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
...
@@ -323,7 +311,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
# we do not keep the gradient after reduction
# we do not keep the gradient after reduction
if
self
.
_partition_grads
and
not
self
.
_param_store
.
belongs_to_current_rank
(
param
):
if
self
.
_partition_grads
and
not
self
.
_param_store
.
belongs_to_current_rank
(
param
):
if
self
.
_overlap_communication
:
if
self
.
_overlap_communication
:
# we need to keep this gradient for now as reduction may
# we need to keep this gradient for now as reduction may
# not be completed yet since it is using a different cuda stream
# not be completed yet since it is using a different cuda stream
self
.
_param_store
.
add_previous_reduced_param
(
param
)
self
.
_param_store
.
add_previous_reduced_param
(
param
)
else
:
else
:
...
@@ -444,7 +432,6 @@ class ShardedOptimizer(ColossalaiOptimizer):
...
@@ -444,7 +432,6 @@ class ShardedOptimizer(ColossalaiOptimizer):
self
.
_grad_store
.
_averaged_gradients
[
group_id
]
=
[]
self
.
_grad_store
.
_averaged_gradients
[
group_id
]
=
[]
self
.
_grad_store
.
_averaged_gradients
[
group_id
]
=
[]
self
.
_grad_store
.
_averaged_gradients
[
group_id
]
=
[]
# unscale and clip grads
# unscale and clip grads
global_norm
=
calculate_global_norm_from_list
(
norm_list
=
norm_groups
)
global_norm
=
calculate_global_norm_from_list
(
norm_list
=
norm_groups
)
self
.
_unscale_and_clip_grads
(
single_grad_partition_groups
,
global_norm
)
self
.
_unscale_and_clip_grads
(
single_grad_partition_groups
,
global_norm
)
...
@@ -501,7 +488,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
...
@@ -501,7 +488,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
def
_unscale_and_clip_grads
(
self
,
grad_groups_flat
,
total_norm
):
def
_unscale_and_clip_grads
(
self
,
grad_groups_flat
,
total_norm
):
# compute combined scale factor for this group
# compute combined scale factor for this group
combined_scale
=
self
.
loss_scale
combined_scale
=
self
.
loss_scale
if
self
.
_clip_grad_norm
>
0.
:
if
self
.
_clip_grad_norm
>
0.
:
# norm is in fact norm*scale
# norm is in fact norm*scale
clip
=
((
total_norm
/
self
.
loss_scale
)
+
1e-6
)
/
self
.
_clip_grad_norm
clip
=
((
total_norm
/
self
.
loss_scale
)
+
1e-6
)
/
self
.
_clip_grad_norm
...
@@ -562,7 +549,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
...
@@ -562,7 +549,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
for
param
in
param_group
:
for
param
in
param_group
:
if
param
.
grad
is
not
None
:
if
param
.
grad
is
not
None
:
self
.
_reduce_and_remove_grads_by_bucket
(
param
)
self
.
_reduce_and_remove_grads_by_bucket
(
param
)
# we need to reduce the gradients
# we need to reduce the gradients
# left in the communication bucket
# left in the communication bucket
self
.
_reduce_grads_in_bucket
()
self
.
_reduce_grads_in_bucket
()
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
e79ea442
...
@@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Union
...
@@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Union
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.nn
as
nn
import
torch.nn
as
nn
from
colossalai.amp.naive_amp.
_fp16_optimiz
er
import
DynamicGradScaler
from
colossalai.amp.naive_amp.
grad_scal
er
import
DynamicGradScaler
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
...
...
tests/test_amp/test_naive_fp16.py
0 → 100644
View file @
e79ea442
import
torch
import
colossalai
import
copy
import
pytest
import
torch.multiprocessing
as
mp
from
colossalai.amp
import
convert_to_naive_amp
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.utils
import
free_port
from
functools
import
partial
def check_equal(a, b):
    """Assert that tensors *a* and *b* are element-wise equal within tolerance.

    Both tensors are promoted to float32 before comparison so that fp16 and
    fp32 operands can be compared directly.
    """
    close = torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3)
    assert close, f'a = {a}, b = {b}'
def run_naive_amp():
    """Compare the naive fp16 optimizer implemented in colossalai against a
    plain fp32 torch optimizer on identical models.

    For each registered test model we run one forward/backward/step cycle on
    an AMP-wrapped copy and an untouched fp32 copy, and assert that outputs,
    gradients and updated parameters agree within fp16 tolerance.
    """
    # models to exercise; both are registered in the component registry
    test_models = ['repeated_computed_layers', 'nested_model']
    for test_name in test_models:
        get_component_func = non_distributed_component_funcs.get_callable(test_name)
        model_builder, train_dataloader, _, optim_builder, _ = get_component_func()

        # create two identical models: one for AMP, one as the fp32 reference
        amp_model = model_builder(checkpoint=True).cuda()
        torch_model = copy.deepcopy(amp_model)

        # create one optimizer per model
        amp_optimizer = optim_builder(amp_model)
        torch_optimizer = optim_builder(torch_model)

        # inject naive amp; initial_scale=1 keeps loss scaling a no-op so the
        # two runs are numerically comparable
        amp_config = dict(initial_scale=1)
        amp_model, amp_optimizer = convert_to_naive_amp(amp_model, amp_optimizer, amp_config)

        # create data
        data_iter = iter(train_dataloader)
        data, label = next(data_iter)
        data = data.cuda()

        # forward pass
        amp_output = amp_model(data)
        torch_output = torch_model(data)
        assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), \
            f'{amp_output} vs {torch_output}'

        # backward
        amp_optimizer.backward(amp_output.mean())
        torch_output.mean().backward()

        # check grad
        # BUGFIX: the original called torch.allclose without asserting, so a
        # mismatch could never fail the test
        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
            assert torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3), \
                f'grad mismatch: {amp_param.grad} vs {torch_param.grad}'

        # step
        amp_optimizer.step()
        torch_optimizer.step()

        # check updated param
        # BUGFIX: same as above — the comparison result was silently discarded
        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
            assert torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3), \
                f'param mismatch: {amp_param} vs {torch_param}'
def run_dist(rank, world_size, port):
    """Per-process entry point: initialise the colossalai runtime for this
    rank, then run the naive-AMP comparison test.
    """
    launch_cfg = dict(config=dict(),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port)
    colossalai.launch(**launch_cfg)
    run_naive_amp()
@pytest.mark.dist
def test_naive_amp():
    """Spawn a single-process distributed run of the naive-AMP test on a free
    port; mp.spawn supplies the rank as the first positional argument.
    """
    world_size = 1
    spawn_target = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(spawn_target, nprocs=world_size)
# allow running this test directly without pytest
if __name__ == '__main__':
    test_naive_amp()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment