OpenDAS / ColossalAI · Commits

Commit 3d5d64bd, authored Mar 09, 2022 by Frank Lee

refactored grad scaler (#338)

parent 6a318816
Showing 4 changed files with 135 additions and 0 deletions (+135 -0)
colossalai/amp/naive_amp/grad_scaler/__init__.py              +5   -0
colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py      +46  -0
colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py  +16  -0
colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py   +68  -0
colossalai/amp/naive_amp/grad_scaler/__init__.py (new file, mode 100644)

from .base_grad_scaler import BaseGradScaler
from .constant_grad_scaler import ConstantGradScaler
from .dynamic_grad_scaler import DynamicGradScaler

__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler']
colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py (new file, mode 100644)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from abc import ABC, abstractmethod
from colossalai.logging import get_dist_logger
from torch import Tensor
from typing import Dict

__all__ = ['BaseGradScaler']


class BaseGradScaler(ABC):
    """Base class for the gradient scalers used by naive AMP."""

    def __init__(self, initial_scale: int, verbose: bool):
        assert initial_scale > 0
        self._scale = torch.cuda.FloatTensor([initial_scale])
        self._verbose = verbose

        if self._verbose:
            self._logger = get_dist_logger()

    @property
    def scale(self) -> Tensor:
        """The current loss scale."""
        return self._scale

    @property
    def inv_scale(self) -> Tensor:
        """The reciprocal of the current loss scale, computed in double
        precision to limit rounding error before casting back to float."""
        return self._scale.double().reciprocal().float()

    # state_dict/load_state_dict are concrete: they already carry full
    # implementations, and subclasses are expected to inherit them as-is
    def state_dict(self) -> Dict:
        state_dict = dict()
        state_dict['scale'] = self.scale
        return state_dict

    def load_state_dict(self, state_dict: Dict) -> None:
        self._scale = state_dict['scale']

    @abstractmethod
    def update(self, overflow: bool) -> None:
        """Adjust the scale after a step; `overflow` reports whether the last
        backward pass produced inf/nan gradients."""
        pass

    def log(self, message, *args, **kwargs):
        if self._verbose:
            self._logger.info(message, *args, **kwargs)
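For context, an illustrative sketch (not part of this commit): a concrete scaler only has to provide update(), since state_dict() and load_state_dict() are inherited from the base class. The HalvingGradScaler name below is hypothetical, and instantiating it needs a CUDA device because the base class stores the scale in a torch.cuda.FloatTensor.

from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler


# Hypothetical subclass for illustration: halves the scale on every
# overflow and never grows it back.
class HalvingGradScaler(BaseGradScaler):

    def update(self, overflow: bool) -> None:
        if overflow:
            self._scale = self._scale * 0.5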
colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py (new file, mode 100644)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from .base_grad_scaler import BaseGradScaler

__all__ = ['ConstantGradScaler']


class ConstantGradScaler(BaseGradScaler):

    def __init__(self, initial_scale: int, verbose: bool):
        super().__init__(initial_scale, verbose)
        self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0])

    def update(self, overflow: bool) -> None:
        # do nothing to maintain the current scale value
        pass
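A minimal usage sketch (illustrative, not part of the commit; requires CUDA because the base class allocates the scale with torch.cuda.FloatTensor): the constant scaler ignores overflow reports entirely.

from colossalai.amp.naive_amp.grad_scaler import ConstantGradScaler

scaler = ConstantGradScaler(initial_scale=2**16, verbose=False)
scaler.update(overflow=True)    # no-op: the scale never changes
assert scaler.scale.item() == 2**16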
colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py (new file, mode 100644)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from typing import Optional

from .base_grad_scaler import BaseGradScaler

__all__ = ['DynamicGradScaler']


class DynamicGradScaler(BaseGradScaler):

    def __init__(self,
                 initial_scale: int = 2**16,
                 growth_factor: int = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 min_scale: Optional[int] = None,
                 max_scale: Optional[int] = None,
                 hysteresis: Optional[int] = None,
                 verbose: bool = False):
        super().__init__(initial_scale, verbose)
        # keep the bounds as tensors so that torch.max/torch.min in
        # _backoff_scale/_grow_scale receive tensor operands
        if min_scale is not None:
            self._min_scale = torch.cuda.FloatTensor([min_scale])
        else:
            self._min_scale = None
        if max_scale is not None:
            self._max_scale = torch.cuda.FloatTensor([max_scale])
        else:
            self._max_scale = None
        self._growth_factor = growth_factor
        self._backoff_factor = backoff_factor
        self._growth_interval = growth_interval
        self._growth_step = 0
        self._hysteresis = hysteresis
        self._hysteresis_step = 0
        self._sanity_checks()

    def _sanity_checks(self) -> None:
        if self._min_scale is not None:
            assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
        if self._max_scale is not None:
            assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
        assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
        assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
        if self._hysteresis is not None:
            assert self._hysteresis >= 0, 'The hysteresis cannot be negative'

    def update(self, overflow: bool) -> None:
        if overflow:
            # an overflow resets the growth counter; the scale is backed off
            # only after `hysteresis` consecutive overflows, or immediately
            # when no hysteresis is configured
            self._hysteresis_step += 1
            self._growth_step = 0

            if self._hysteresis is None or self._hysteresis_step >= self._hysteresis:
                self._backoff_scale()
                self.log(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}", ranks=[0])
        else:
            self._growth_step += 1
            if self._growth_step == self._growth_interval:
                self._growth_step = 0
                self._hysteresis_step = 0
                self._grow_scale()
                self.log(
                    f"No overflow for consecutive {self._growth_interval} steps, "
                    f"the loss scale is adjusted to {self.scale.item()}",
                    ranks=[0])

    def _backoff_scale(self) -> None:
        self._scale = self._scale * self._backoff_factor
        if self._min_scale is not None:
            self._scale = torch.max(self._scale, self._min_scale)

    def _grow_scale(self) -> None:
        self._scale = self._scale * self._growth_factor
        if self._max_scale is not None:
            self._scale = torch.min(self._scale, self._max_scale)
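A minimal training-loop sketch (illustrative, not part of this commit) showing how the dynamic scaler is meant to be driven. Assumptions: a CUDA device is available, hysteresis=2 is an arbitrary illustrative value, and the inline inf/nan scan stands in for whatever overflow check the calling optimizer actually performs.

import torch

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

scaler = DynamicGradScaler(initial_scale=2**16, growth_interval=1000, hysteresis=2)
model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4, device='cuda')).sum()
    (loss * scaler.scale).backward()    # scale the loss before backward

    # unscale gradients and scan for inf/nan before stepping
    overflow = False
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scaler.inv_scale)
            if not torch.isfinite(p.grad).all():
                overflow = True

    if not overflow:
        optimizer.step()
    scaler.update(overflow)    # grow or back off the loss scale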