ColossalAI · Commit 9fdebadd (unverified)

[doc] improved docstring in the amp module (#857)

Authored Apr 25, 2022 by Frank Lee; committed by GitHub on Apr 25, 2022
Parent: b862d89d

Showing 9 changed files with 147 additions and 10 deletions (+147, -10)
colossalai/amp/apex_amp/__init__.py                            +3   -3
colossalai/amp/apex_amp/apex_amp.py                            +1   -1
colossalai/amp/naive_amp/__init__.py                           +2   -3
colossalai/amp/naive_amp/_fp16_optimizer.py                    +50  -0
colossalai/amp/naive_amp/_utils.py                             +10  -2
colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py       +35  -0
colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py   +11  -1
colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py    +26  -0
colossalai/amp/torch_amp/torch_amp.py                          +9   -0
colossalai/amp/apex_amp/__init__.py

@@ -11,6 +11,9 @@ def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for initializing apex_amp.

+   Returns:
+       Tuple: A tuple (model, optimizer).

    The ``amp_config`` should include parameters below:
    ::
...
@@ -27,9 +30,6 @@ def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
        min_loss_scale (float, default=None)
        max_loss_scale (float, default=2.**24)

-   Returns:
-       Tuples: A tuple (model, optimizer).

    More details about ``amp_config`` refer to `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
    """
    import apex.amp as apex_amp
...
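The docstring above describes convert_to_apex_amp as taking a model, an optimizer, and an amp_config, and returning a (model, optimizer) tuple. A minimal usage sketch follows; the toy model, optimizer, and the opt_level key are illustrative assumptions (the keys follow apex.amp conventions), not values taken from this commit.

import torch
import torch.nn as nn
from colossalai.amp.apex_amp import convert_to_apex_amp

# hypothetical model and optimizer; any nn.Module / torch.optim.Optimizer pair works
model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# amp_config mirrors the parameters listed in the docstring (apex.amp-style keys)
amp_config = dict(opt_level='O1', max_loss_scale=2.**24)
model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)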
colossalai/amp/apex_amp/apex_amp.py

@@ -28,7 +28,7 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
            scaled_loss.backward()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
-       """Clip gradients' norm
+       """Clip gradients by norm

        Args:
            model (torch.nn.Module): Your model object
...
colossalai/amp/naive_amp/__init__.py

@@ -17,6 +17,8 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object
        amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.

+   Returns:
+       Tuple: A tuple (model, optimizer)

    The ``amp_config`` should contain parameters below::
...
@@ -24,9 +26,6 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
        clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
            Note that clipping is ignored if clip_grad == 0.
        dynamic_grad_scale (bool): whether to use dynamic grad scaler.

-   Returns:
-       Tuples: A tuple (model, optimizer)
    """
    if isinstance(model, nn.ModuleList):
        # interleaved pipeline
...
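As with the apex variant, convert_to_naive_amp returns a (model, optimizer) tuple. A hedged usage sketch, assuming a plain dict is accepted for amp_config, using only the keys named in the docstring above and leaving any unspecified fields at their defaults:

import torch
import torch.nn as nn
from colossalai.amp.naive_amp import convert_to_naive_amp

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# keys follow the docstring above; clip_grad_norm=0 would disable clipping
amp_config = dict(clip_grad_norm=1.0, dynamic_grad_scale=True)
model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)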
colossalai/amp/naive_amp/_fp16_optimizer.py

@@ -152,18 +152,39 @@ class FP16Optimizer(Optimizer):
    @property
    def grad_scaler(self):
        """Returns the gradient scaler.

        Returns:
            :class:`BaseGradScaler`: gradient scaler.
        """
        return self._grad_scaler

    @property
    def loss_scale(self):
        """Returns the loss scale.

        Returns:
            int: loss scale.
        """
        return self._grad_scaler.scale

    @property
    def optimizer(self):
        """Returns the optimizer.

        Returns:
            :class:`torch.optim.Optimizer`: the optimizer object wrapped.
        """
        return self._optimizer

    @property
    def defaults(self):
        """Returns the default arguments of optimizer.

        Returns:
            dict: optimizer arguments saved in defaults of the optimizer wrapped.
        """
        return self._defaults

    def _check_overflow(self):
...
@@ -188,6 +209,12 @@ class FP16Optimizer(Optimizer):
        return self._found_overflow.item() > 0

    def zero_grad(self, set_to_none=True):
        """Set gradient to zero.

        Args:
            set_to_none (bool): Whether set the gradient to None.
        """
        # set_to_none = True can save some memory space
        for param_group in self._optimizer.param_groups:
            zero_gard_by_list(param_group['params'], set_to_none=set_to_none)
...
@@ -222,6 +249,9 @@ class FP16Optimizer(Optimizer):
            overflow_buf=self._dummy_overflow_buf)

    def step(self):
        """Update the model parameters.
        """
        # Copy gradients from model params to main params.
        self._assign_grad_to_fp32_master_param()
        self._unscale_grads()
...
@@ -248,10 +278,19 @@ class FP16Optimizer(Optimizer):
        return True, grad_norm

    def backward(self, loss):
        """Execute backward pass.

        Args:
            loss (:class:`torch.Tensor`): the loss value.
        """
        scaled_loss = loss * self.grad_scaler.scale
        scaled_loss.backward()

    def state_dict(self):
        """Returns the states of the fp16 optimizer as a dict object.
        """
        state_dict = {}
        state_dict['optimizer'] = self._optimizer.state_dict()
        if self.grad_scaler:
...
@@ -260,6 +299,12 @@ class FP16Optimizer(Optimizer):
        return state_dict

    def load_state_dict(self, state_dict):
        """Load the states of the fp16 optimizer from a dict object.

        Args:
            state_dict (dict): the states of the fp16 optimizer
        """
        # Optimizer.
        self._optimizer.load_state_dict(state_dict['optimizer'])
...
@@ -275,6 +320,11 @@ class FP16Optimizer(Optimizer):
            current_param.data.copy_(ckpt_param.data)

    def clip_grad_norm(self, clip_grad):
        """Clip gradients by norm.

        Args:
            clip_grad (float): the max norm for clipping
        """
        params = []
        for param_group in self._optimizer.param_groups:
            for param in param_group['params']:
...
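The newly documented methods above sketch the life cycle of FP16Optimizer: backward() scales the loss before backpropagation, step() copies gradients to the fp32 master parameters and unscales them before the update, and state_dict()/load_state_dict() round-trip the wrapped optimizer plus the grad scaler. A hedged training-loop sketch, assuming fp16_optim is an already constructed FP16Optimizer (its constructor is not part of this diff) and that model, criterion, and dataloader are placeholders:

for inputs, targets in dataloader:            # hypothetical data loader
    fp16_optim.zero_grad()                    # set_to_none=True by default, saves memory
    loss = criterion(model(inputs), targets)
    fp16_optim.backward(loss)                 # multiplies loss by grad_scaler.scale, then backward
    fp16_optim.step()                         # copy grads to fp32 master params, unscale, update

# checkpointing round-trips through the documented state_dict()/load_state_dict()
checkpoint = fp16_optim.state_dict()
fp16_optim.load_state_dict(checkpoint)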
colossalai/amp/naive_amp/_utils.py

@@ -3,6 +3,14 @@ from torch import Tensor
def has_inf_or_nan(tensor):
    """Check if tensor has inf or nan values.

    Args:
        tensor (:class:`torch.Tensor`): a torch tensor object

    Returns:
        bool: Whether the tensor has inf or nan. True for yes and False for no.
    """
    try:
        # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
        # Pytorch's .sum() creates a one-element tensor of the same type as tensor
...
@@ -24,8 +32,8 @@ def has_inf_or_nan(tensor):
def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
    """
    Clear the gradient of a list of tensors,

    Note: copied from torch.optim.optimizer.
    """
    for param in tensor_list:
...
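For reference, a small sketch of how the two documented helpers behave; the import path follows the file location in this diff, and the toy tensors are illustrative only:

import torch
from colossalai.amp.naive_amp._utils import has_inf_or_nan, zero_gard_by_list

t = torch.tensor([1.0, float('inf')])
print(has_inf_or_nan(t))                      # True: the tensor contains an inf value

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
(params[0] * 2).sum().backward()              # give the first parameter a gradient
zero_gard_by_list(params, set_to_none=True)   # clears .grad on every tensor in the list
print(params[0].grad)                         # None, since set_to_none=True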
colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py

@@ -11,6 +11,12 @@ __all__ = ['BaseGradScaler']
class BaseGradScaler(ABC):
    """A base class for the gradient scaler.

    Args:
        initial_scale (float): the initial loss scale
        verbose (bool): whether to log messages
    """

    def __init__(self, initial_scale: float, verbose: bool):
        assert initial_scale > 0
...
@@ -22,24 +28,53 @@ class BaseGradScaler(ABC):
    @property
    def scale(self) -> Tensor:
        """Returns the loss scale.
        """
        return self._scale

    @property
    def inv_scale(self) -> Tensor:
        """Returns the inverse of the loss scale.
        """
        return self._scale.double().reciprocal().float()

    def state_dict(self) -> Dict:
        """Returns the states of the gradient scaler as a dict object.
        """
        state_dict = dict()
        state_dict['scale'] = self.scale
        return state_dict

    def load_state_dict(self, state_dict: Dict) -> None:
        """Load the states of the gradient scaler from a dict object.

        Args:
            state_dict (dict): the states of the gradient scaler
        """
        self._scale = state_dict['scale']

    @abstractmethod
    def update(self, overflow: bool) -> None:
        """Update the loss scale.

        Args:
            overflow (bool): whether overflow occurs
        """
        pass

    def log(self, message, *args, **kwargs):
        """Log messages.

        Args:
            message (str): the message to log
            *args: positional arguments for :class:`colossalai.logging.DistributedLogger`
            **kwargs: key-word arguments for :class:`colossalai.logging.DistributedLogger`
        """
        if self._verbose:
            self._logger.info(message, *args, **kwargs)
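The docstrings above make the contract of BaseGradScaler explicit: subclasses only need to implement update(overflow), while scale, inv_scale, state_dict, load_state_dict, and log come from the base class. A minimal subclass sketch; the HalvingGradScaler class and its policy are invented here purely for illustration, and the import assumes the grad_scaler package re-exports BaseGradScaler (otherwise import it from base_grad_scaler directly):

from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler

class HalvingGradScaler(BaseGradScaler):
    """Illustrative subclass: halve the scale on overflow, otherwise leave it alone."""

    def update(self, overflow: bool) -> None:
        # update() is the only abstract method; subclasses decide how the scale evolves
        if overflow:
            self._scale = self._scale * 0.5
            self.log(f"overflow detected, scale backed off to {self.scale}", ranks=[0])

scaler = HalvingGradScaler(initial_scale=2**16, verbose=False)
print(scaler.state_dict())    # {'scale': ...}, as defined in the base class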
colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py

@@ -6,11 +6,21 @@ __all__ = ['ConstantGradScaler']
class ConstantGradScaler(BaseGradScaler):
    """A gradient scaler which uses constant loss scale

    Args:
        initial_scale (float): the initial loss scale
        verbose (bool): whether to log messages
    """

    def __init__(self, initial_scale: int, verbose: bool):
        super().__init__(initial_scale, verbose)
        self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0])

    def update(self, overflow: bool) -> None:
-       # do nothing to maintain the current scale value
+       """Do nothing to keep the loss scale constant.
+
+       Args:
+           overflow (bool): whether overflow occurs
+       """
        pass
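A short behavioral sketch of the constant scaler as documented: update() is a no-op, so the loss scale never changes regardless of overflow. The import path assumes the grad_scaler package exports the class:

from colossalai.amp.naive_amp.grad_scaler import ConstantGradScaler

scaler = ConstantGradScaler(initial_scale=2**15, verbose=False)
scaler.update(overflow=True)    # documented as a no-op
scaler.update(overflow=False)   # still a no-op
print(scaler.scale)             # remains at the initial scale, 2**15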
colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py

@@ -9,6 +9,18 @@ __all__ = ['DynamicGradScaler']
class DynamicGradScaler(BaseGradScaler):
    """A gradient scaler which uses dynamic loss scale

    Args:
        initial_scale (float): the initial loss scale, defaults to 2**16
        growth_factor (float): the multiplication factor for increasing loss scale, defaults to 2
        backoff_factor (float): the multiplication factor for decreasing loss scale, defaults to 0.5
        growth_interval (int): the number of steps to increase loss scale when no overflow occurs, defaults to 1000
        min_scale (float): the minimum loss scale, defaults to None
        max_scale (float): the maximum loss scale, defaults to None
        hysteresis (int): the number of overflows before decreasing loss scale, defaults to 2
        verbose (bool): whether to log messages, defaults to False
    """

    def __init__(self,
                 initial_scale: float = 2**16,
...
@@ -39,6 +51,9 @@ class DynamicGradScaler(BaseGradScaler):
        self._sanity_checks()

    def _sanity_checks(self) -> None:
        """Check if the arguments are correct.
        """
        if self._min_scale:
            assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
        if self._max_scale:
...
@@ -48,6 +63,11 @@ class DynamicGradScaler(BaseGradScaler):
        assert self._hysteresis >= 0, 'The hysteresis cannot be negative'

    def update(self, overflow: bool) -> None:
        """Update the loss scale.

        Args:
            overflow (bool): whether overflow occurs
        """
        if overflow:
            self._hysteresis_step += 1
            self._growth_step = 0
...
@@ -67,11 +87,17 @@ class DynamicGradScaler(BaseGradScaler):
            ranks=[0])

    def _backoff_scale(self) -> None:
        """Decrease the loss scale
        """
        self._scale = self._scale * self._backoff_factor
        if self._min_scale:
            self._scale = torch.max(self._scale, self._min_scale)

    def _grow_scale(self) -> None:
        """Increase the loss scale
        """
        self._scale = self._scale * self._growth_factor
        if self._max_scale:
            self._scale = torch.min(self._scale, self._max_scale)
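The documented parameters describe the usual dynamic loss-scaling policy: after hysteresis overflows the scale is multiplied by backoff_factor (clamped at min_scale), and after growth_interval consecutive clean steps it is multiplied by growth_factor (clamped at max_scale). A hedged usage sketch using only the constructor arguments listed in the docstring; driving update() by hand is illustrative, in a real training run the wrapping optimizer would make these calls:

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

scaler = DynamicGradScaler(initial_scale=2**16,
                           growth_factor=2,
                           backoff_factor=0.5,
                           growth_interval=1000,
                           hysteresis=2,
                           verbose=False)

scaler.update(overflow=True)      # first overflow: hysteresis not yet exhausted
scaler.update(overflow=True)      # second overflow: scale is expected to back off by 0.5

for _ in range(1000):             # a full growth_interval without overflow
    scaler.update(overflow=False)
print(scaler.scale)               # expected to have grown by growth_factor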
colossalai/amp/torch_amp/torch_amp.py

@@ -62,6 +62,9 @@ class TorchAMPOptimizer(ColossalaiOptimizer):
class TorchAMPModel(nn.Module):
    """A wrapper class for a model object which executes forward with values automatically
    cast to fp16

    Args:
        model (:class:`torch.nn.Module`): a torch model instance
    """

    def __init__(self, model: nn.Module) -> None:
...
@@ -70,6 +73,9 @@ class TorchAMPModel(nn.Module):
    @torch_amp.autocast()
    def forward(self, *args, **kwargs):
        """
        Execute forward under the torch amp context
        """
        return self.model(*args, **kwargs)
...
@@ -86,4 +92,7 @@ class TorchAMPLoss(nn.Module):
    @torch_amp.autocast()
    def forward(self, *args, **kwargs):
        """
        Execute forward under the torch amp context
        """
        return self.loss(*args, **kwargs)
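As documented, both wrappers simply run their wrapped module under the torch amp autocast context. A hedged usage sketch; the import path mirrors the file location in this diff, the TorchAMPLoss constructor is assumed to take the wrapped loss module (only its forward appears in this diff), and a CUDA device is assumed to be available:

import torch
import torch.nn as nn
from colossalai.amp.torch_amp.torch_amp import TorchAMPModel, TorchAMPLoss

model = TorchAMPModel(nn.Linear(16, 4).cuda())
criterion = TorchAMPLoss(nn.CrossEntropyLoss())     # constructor signature assumed

x = torch.randn(8, 16, device='cuda')
y = torch.randint(0, 4, (8,), device='cuda')

# both forward passes run under the torch amp autocast context, as the docstrings state
logits = model(x)
loss = criterion(logits, y)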