OpenDAS / ColossalAI · Commits

Commit b0f708df — fix format (#570)
Authored Mar 31, 2022 by Kai Wang (Victor Kai); committed by binmakeswell on Apr 06, 2022.
Parent: 2a915a8b

Showing 1 changed file with 40 additions and 70 deletions (+40 −70):

colossalai/amp/torch_amp/_grad_scaler.py
```diff
@@ -27,8 +27,7 @@ class _MultiDeviceReplicator(object):
     def get(self, device) -> torch.Tensor:
         retval = self._per_device_tensors.get(device, None)
         if retval is None:
-            retval = self.master.to(device=device, non_blocking=True,
-                                    copy=True)
+            retval = self.master.to(device=device, non_blocking=True, copy=True)
             self._per_device_tensors[device] = retval
         return retval
```
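The hunk only rewraps the call, but the method it lives in is the heart of `_MultiDeviceReplicator`: a master tensor is lazily copied to each device on first request and served from a cache afterwards. A minimal standalone sketch of the same pattern (the `PerDeviceCache` name is illustrative, not from the diff):

```python
import torch

class PerDeviceCache:
    """Sketch of the lazy per-device replication used by _MultiDeviceReplicator:
    copy a master tensor to each device on first request, then reuse the copy."""

    def __init__(self, master: torch.Tensor) -> None:
        self.master = master
        self._per_device: dict = {}

    def get(self, device) -> torch.Tensor:
        tensor = self._per_device.get(device)
        if tensor is None:
            # non_blocking=True lets the copy overlap with other work when possible
            tensor = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device[device] = tensor
        return tensor

# Usage: replicate a scalar scale factor across devices on demand.
cache = PerDeviceCache(torch.full((1,), 65536.0))
first = cache.get(torch.device("cpu"))    # first call copies
second = cache.get(torch.device("cpu"))   # second call hits the cache
assert first is second
```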
```diff
@@ -116,15 +115,9 @@ class GradScaler(object):
     invokes the underlying ``optimizer.step()``, and other methods become no-ops.
     """

-    def __init__(self,
-                 init_scale=2.**16,
-                 growth_factor=2.0,
-                 backoff_factor=0.5,
-                 growth_interval=2000,
-                 enabled=True):
+    def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
         if enabled and not torch.cuda.is_available():
-            warnings.warn(
-                "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
+            warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
             self._enabled = False
         else:
             self._enabled = enabled
```
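The class keeps the public `torch.cuda.amp.GradScaler` protocol, so the standard AMP training loop applies to it unchanged; a minimal sketch (model, optimizer, and data are placeholders, and a CUDA device is assumed since, per the hunk above, the scaler disables itself without one):

```python
import torch

model = torch.nn.Linear(8, 1).cuda()        # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()        # init_scale=2.**16 by default

for _ in range(10):
    inputs = torch.randn(4, 8, device="cuda")
    targets = torch.randn(4, 1, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()   # scale the loss so FP16 grads don't underflow
    scaler.step(optimizer)          # unscales grads; skips the step on inf/nan
    scaler.update()                 # grows or backs off the scale factor
```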
```diff
@@ -142,23 +135,18 @@ class GradScaler(object):
             self._init_growth_tracker = 0
             # self._growth_tracker will be lazily initialized during the first call to scale()
             self._growth_tracker = None
-            self._per_optimizer_states = defaultdict(
-                _refresh_per_optimizer_state)
+            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

     def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
         fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
-        assert self._scale is not None, "Attempted {} but _scale is None.  ".format(
-            funcname) + fix
-        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None.  ".format(
-            funcname) + fix
+        assert self._scale is not None, "Attempted {} but _scale is None.  ".format(funcname) + fix
+        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None.  ".format(funcname) + fix
         return (self._scale, self._growth_tracker)

     def _lazy_init_scale_growth_tracker(self, dev):
         assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
-        self._scale = torch.full(
-            (1,), self._init_scale, dtype=torch.float32, device=dev)
-        self._growth_tracker = torch.full(
-            (1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
+        self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev)
+        self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev)

     def scale(self, outputs):
         """
```
```diff
@@ -201,8 +189,7 @@ class GradScaler(object):
             else:
                 return iterable
         else:
-            raise ValueError(
-                "outputs must be a Tensor or an iterable of Tensors")
+            raise ValueError("outputs must be a Tensor or an iterable of Tensors")

         return apply_scale(outputs)
```
```diff
@@ -216,16 +203,14 @@ class GradScaler(object):
         # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
         # Google says mypy struggles with defaultdicts type annotations.
-        per_device_and_dtype_grads = defaultdict(
-            lambda: defaultdict(list))  # type: ignore[var-annotated]
+        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))    # type: ignore[var-annotated]
         with torch.no_grad():
             for group in optimizer.param_groups:
                 for param in group["params"]:
                     if param.grad is None:
                         continue
                     if (not allow_fp16) and param.grad.dtype == torch.float16:
-                        raise ValueError(
-                            "Attempting to unscale FP16 gradients.")
+                        raise ValueError("Attempting to unscale FP16 gradients.")
                     if param.grad.is_sparse:
                         # is_coalesced() == False means the sparse grad has values with duplicate indices.
                         # coalesce() deduplicates indices and adds all values that have the same index.
```
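The nested `defaultdict` being reflowed above buckets gradients first by device, then by dtype, so that each homogeneous bucket can later be unscaled with a single fused kernel call. A small CPU-only illustration of the bucketing (tensor values are arbitrary):

```python
from collections import defaultdict
import torch

# device -> dtype -> list of gradients, with buckets created on first access
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))

grads = [torch.ones(2), torch.ones(3, dtype=torch.float64)]
for g in grads:
    per_device_and_dtype_grads[g.device][g.dtype].append(g)

for device, per_dtype in per_device_and_dtype_grads.items():
    for dtype, bucket in per_dtype.items():
        print(device, dtype, len(bucket))
# cpu torch.float32 1
# cpu torch.float64 1
```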
```diff
@@ -238,22 +223,17 @@ class GradScaler(object):
                         to_unscale = param.grad

                     # TODO: is there a way to split by device and dtype without appending in the inner loop?
-                    per_device_and_dtype_grads[to_unscale.device][
-                        to_unscale.dtype].append(to_unscale)
+                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)

             for device, per_dtype_grads in per_device_and_dtype_grads.items():
                 for grads in per_dtype_grads.values():
-                    torch._amp_foreach_non_finite_check_and_unscale_(
-                        grads, per_device_found_inf.get(device),
-                        per_device_inv_scale.get(device))
+                    torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
+                                                                     per_device_inv_scale.get(device))
         # For tensor parallel parameters it should be all-reduced over the tensor parallel process group
         if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
             vals = [val for val in per_device_found_inf._per_device_tensors.values()]
             coalesced = _flatten_dense_tensors(vals)
-            dist.all_reduce(coalesced,
-                            op=dist.ReduceOp.MAX,
-                            group=gpc.get_group(ParallelMode.MODEL))
+            dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
             for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
                 buf.copy_(synced)
         return per_device_found_inf._per_device_tensors
```
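This hunk carries ColossalAI's main functional departure from the upstream scaler: with tensor parallelism, each rank only inspects its own gradient shards, so the 0.0/1.0 found-inf flags must be combined across the MODEL process group before anyone decides to step. A hedged standalone sketch of the flatten, all-reduce(MAX), unflatten pattern (assumes an initialized process group; `group` stands in for `gpc.get_group(ParallelMode.MODEL)`):

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def sync_found_inf(found_inf_tensors, group=None):
    """Combine inf/nan flags across a tensor-parallel group: if any rank saw
    a non-finite gradient, every rank must skip the step, so MAX over the
    0.0/1.0 flags yields the group-wide verdict."""
    flat = _flatten_dense_tensors(found_inf_tensors)         # one buffer, one collective
    dist.all_reduce(flat, op=dist.ReduceOp.MAX, group=group)
    for buf, synced in zip(found_inf_tensors, _unflatten_dense_tensors(flat, found_inf_tensors)):
        buf.copy_(synced)                                    # write results back in place
```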
```diff
@@ -298,19 +278,16 @@ class GradScaler(object):
         optimizer_state = self._per_optimizer_states[id(optimizer)]

         if optimizer_state["stage"] is OptState.UNSCALED:
-            raise RuntimeError(
-                "unscale_() has already been called on this optimizer since the last update().")
+            raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
         elif optimizer_state["stage"] is OptState.STEPPED:
             raise RuntimeError("unscale_() is being called after step().")

         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
         assert self._scale is not None
         inv_scale = self._scale.double().reciprocal().float()
-        found_inf = torch.full(
-            (1,), 0.0, dtype=torch.float32, device=self._scale.device)
+        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)

-        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
-            optimizer, inv_scale, found_inf, False)
+        optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
         optimizer_state["stage"] = OptState.UNSCALED

     def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
```
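The comment in this hunk explains the `double().reciprocal().float()` dance: the inverse scale gets multiplied into every gradient, so it is computed in FP64 and only then cast back to FP32. A tiny demonstration of the same recipe:

```python
import torch

scale = torch.full((1,), 2.0**16)
inv_scale = scale.double().reciprocal().float()  # FP64 reciprocal, cast back to FP32

grad = torch.tensor([655.36])   # a gradient that was computed against the scaled loss
unscaled = grad * inv_scale     # recover the true gradient: 655.36 / 65536 = 0.01
print(unscaled)                 # tensor([0.0100])
```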
```diff
@@ -344,16 +321,14 @@ class GradScaler(object):
             return optimizer.step(*args, **kwargs)

         if "closure" in kwargs:
-            raise RuntimeError(
-                "Closure use is not currently supported if GradScaler is enabled.")
+            raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")

         self._check_scale_growth_tracker("step")

         optimizer_state = self._per_optimizer_states[id(optimizer)]

         if optimizer_state["stage"] is OptState.STEPPED:
-            raise RuntimeError(
-                "step() has already been called since the last update().")
+            raise RuntimeError("step() has already been called since the last update().")

         retval = None
```
```diff
@@ -369,11 +344,9 @@ class GradScaler(object):
         if optimizer_state["stage"] is OptState.READY:
             self.unscale_(optimizer)

-        assert len(optimizer_state["found_inf_per_device"]
-                   ) > 0, "No inf checks were recorded for this optimizer."
+        assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."

-        retval = self._maybe_opt_step(
-            optimizer, optimizer_state, *args, **kwargs)
+        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

         optimizer_state["stage"] = OptState.STEPPED
```
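`_maybe_opt_step` itself is untouched by this commit, but it is where the recorded per-device flags take effect. Roughly, following the upstream PyTorch implementation (a sketch from memory, not the exact source):

```python
def maybe_opt_step(optimizer, optimizer_state, *args, **kwargs):
    """Sketch of GradScaler._maybe_opt_step: run the real optimizer step only
    if no device flagged an inf/nan gradient during unscaling."""
    retval = None
    # found_inf_per_device maps device -> 0.0/1.0 flag tensor; any nonzero sum
    # means this iteration's gradients are unusable and the step is skipped,
    # leaving update() to back off the scale afterwards.
    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
        retval = optimizer.step(*args, **kwargs)
    return retval
```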
```diff
@@ -407,35 +380,32 @@ class GradScaler(object):
         if new_scale is not None:
             # Accept a new user-defined scale.
             if isinstance(new_scale, float):
-                self._scale.fill_(new_scale)  # type: ignore[union-attr]
+                self._scale.fill_(new_scale)    # type: ignore[union-attr]
             else:
                 reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                 # type: ignore[attr-defined]
                 assert isinstance(new_scale, torch.cuda.FloatTensor), reason
                 assert new_scale.numel() == 1, reason
                 assert new_scale.requires_grad is False, reason
-                self._scale.copy_(new_scale)  # type: ignore[union-attr]
+                self._scale.copy_(new_scale)    # type: ignore[union-attr]
         else:
             # Consume shared inf/nan data collected from optimizers to update the scale.
             # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
-            found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
-                          for state in self._per_optimizer_states.values()
-                          for found_inf in state["found_inf_per_device"].values()]
+            found_infs = [
+                found_inf.to(device=_scale.device, non_blocking=True) for state in self._per_optimizer_states.values()
+                for found_inf in state["found_inf_per_device"].values()
+            ]

-            assert len(
-                found_infs) > 0, "No inf checks were recorded prior to update."
+            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

             found_inf_combined = found_infs[0]
             if len(found_infs) > 1:
                 for i in range(1, len(found_infs)):
                     found_inf_combined += found_infs[i]

-            torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined,
-                                     self._growth_factor, self._backoff_factor,
-                                     self._growth_interval)
+            torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
+                                     self._backoff_factor, self._growth_interval)

         # To prepare for next iteration, clear the data collected from optimizers this iteration.
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
```
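`torch._amp_update_scale_` is a fused kernel, but the policy it implements is small; a pure-Python sketch of the growth/backoff rule, using this class's defaults (`growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`):

```python
def update_scale(scale, growth_tracker, found_inf,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    """Pure-Python sketch of the dynamic loss-scale update rule."""
    if found_inf:
        scale *= backoff_factor      # overflow: halve the scale, reset the streak
        growth_tracker = 0
    else:
        growth_tracker += 1
        if growth_tracker == growth_interval:
            scale *= growth_factor   # 2000 clean iterations: double the scale
            growth_tracker = 0
    return scale, growth_tracker

# e.g. a single overflow drops the scale from 65536.0 to 32768.0:
print(update_scale(65536.0, 120, found_inf=True))   # (32768.0, 0)
```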
```diff
@@ -522,11 +492,13 @@ class GradScaler(object):
         If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
         should be called after :meth:`update`.
         """
-        return {"scale": self.get_scale(),
-                "growth_factor": self._growth_factor,
-                "backoff_factor": self._backoff_factor,
-                "growth_interval": self._growth_interval,
-                "_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
+        return {
+            "scale": self.get_scale(),
+            "growth_factor": self._growth_factor,
+            "backoff_factor": self._backoff_factor,
+            "growth_interval": self._growth_interval,
+            "_growth_tracker": self._get_growth_tracker()
+        } if self._enabled else {}

     def load_state_dict(self, state_dict):
         r"""
```
```diff
@@ -572,10 +544,8 @@ class GradScaler(object):
     def _check_inf_per_device(self, optimizer):
         _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

-        dummy_inv_scale = torch.full(
-            (1,), 1.0, dtype=torch.float32, device=_scale.device)
-        found_inf = torch.full(
-            (1,), 0.0, dtype=torch.float32, device=_scale.device)
+        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
+        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)

         self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
             self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
```