OpenDAS / ColossalAI / Commits

Commit b0f708df
authored Mar 31, 2022 by Kai Wang (Victor Kai)
committed by binmakeswell, Apr 06, 2022

fix format (#570)

parent 2a915a8b
Changes 1

Showing 1 changed file with 40 additions and 70 deletions.

colossalai/amp/torch_amp/_grad_scaler.py  (+40, -70)  View file @ b0f708df

The commit is formatting-only (re-wrapped long lines); each hunk below is shown as it reads after the commit.
...
@@ -27,8 +27,7 @@ class _MultiDeviceReplicator(object):

    def get(self, device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval
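For reference, _MultiDeviceReplicator keeps one cached copy of the master scale tensor per device, created on first request. A minimal usage sketch (the import path is the file shown in this diff; the constructor taking the master tensor follows the upstream torch.cuda.amp implementation):

    import torch
    from colossalai.amp.torch_amp._grad_scaler import _MultiDeviceReplicator

    scale = torch.full((1,), 2.**16, dtype=torch.float32)
    rep = _MultiDeviceReplicator(scale)
    # The copy for a given device is made once (non-blocking) and then reused.
    assert rep.get(scale.device) is rep.get(scale.device)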
...
@@ -116,15 +115,9 @@ class GradScaler(object):
    invokes the underlying ``optimizer.step()``, and other methods become no-ops.
    """

    def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
        if enabled and not torch.cuda.is_available():
            warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
            self._enabled = False
        else:
            self._enabled = enabled
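The constructor above mirrors torch.cuda.amp.GradScaler: init_scale is the starting loss scale, growth_factor and growth_interval control how the scale grows after runs of finite gradients, and backoff_factor how it shrinks on overflow. A sketch of the standard training loop these knobs plug into (model, optimizer, and loader are illustrative):

    import torch

    model = torch.nn.Linear(8, 8).cuda()                      # hypothetical model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(init_scale=2.**16, growth_factor=2.0,
                                       backoff_factor=0.5, growth_interval=2000)

    for inputs, targets in loader:                            # `loader` assumed given
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
        scaler.scale(loss).backward()   # backward on the scaled loss
        scaler.step(optimizer)          # unscales grads; skips step on inf/nan
        scaler.update()                 # grows or backs off the scale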
...
@@ -142,23 +135,18 @@ class GradScaler(object):
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker = None
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
        assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
        return (self._scale, self._growth_tracker)

    def _lazy_init_scale_growth_tracker(self, dev):
        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
        self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev)
        self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev)

    def scale(self, outputs):
        """
...
@@ -201,8 +189,7 @@ class GradScaler(object):
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)
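apply_scale walks arbitrarily nested containers of outputs, multiplying each tensor by the scale while preserving list/tuple types and leaving other iterables lazy. A self-contained sketch of that recursion (simplified to a plain float scale instead of the per-device replicator):

    import torch
    from collections import abc

    def apply_scale(val, scale):
        if isinstance(val, torch.Tensor):
            return val * scale
        elif isinstance(val, abc.Iterable):
            iterable = map(lambda v: apply_scale(v, scale), val)
            if isinstance(val, (list, tuple)):
                return type(val)(iterable)   # rebuild concrete containers eagerly
            else:
                return iterable              # generic iterables stay lazy
        else:
            raise ValueError("outputs must be a Tensor or an iterable of Tensors")

    scaled = apply_scale([torch.ones(2), (torch.ones(3),)], 4.0)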
...
@@ -216,16 +203,14 @@ class GradScaler(object):
        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # Google says mypy struggles with defaultdicts type annotations.
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
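The nested defaultdict above buckets gradients by device and dtype so that each bucket can be handed to one fused unscale kernel. A standalone sketch of the bucketing itself:

    import torch
    from collections import defaultdict

    per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
    for g in [torch.ones(2), torch.ones(3, dtype=torch.float64)]:
        per_device_and_dtype_grads[g.device][g.dtype].append(g)
    for device, per_dtype in per_device_and_dtype_grads.items():
        for dtype, grads in per_dtype.items():
            print(device, dtype, len(grads))   # one bucket per (device, dtype)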
...
@@ -238,22 +223,17 @@ class GradScaler(object):
                        to_unscale = param.grad

                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
                                                                     per_device_inv_scale.get(device))
        # For tensor parallel parameters it should be all-reduced over the tensor parallel process group
        if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
            vals = [val for val in per_device_found_inf._per_device_tensors.values()]
            coalesced = _flatten_dense_tensors(vals)
            dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
            for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
                buf.copy_(synced)
        return per_device_found_inf._per_device_tensors
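The tensor-parallel branch ORs the per-device found_inf flags across the MODEL process group by flattening them into one buffer, all-reducing with MAX, and scattering the result back. A sketch of that pattern in isolation (assumes a process group has already been initialized, e.g. via dist.init_process_group):

    import torch
    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    def sync_found_inf(vals, group=None):
        # One collective over a single contiguous buffer instead of one per flag.
        coalesced = _flatten_dense_tensors(vals)
        dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=group)
        for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
            buf.copy_(synced)   # write the reduced flags back in place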
...
@@ -298,19 +278,16 @@ class GradScaler(object):
        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
        optimizer_state["stage"] = OptState.UNSCALED

    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
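On the FP64 comment above: computing 1/scale directly in FP32 can be off by an ULP under some compile options, so the reciprocal is rounded through double precision. A quick check:

    import torch

    scale = torch.full((1,), 2.**16, dtype=torch.float32)
    inv_scale = scale.double().reciprocal().float()
    print(inv_scale.item())   # 1.52587890625e-05, i.e. exactly 2**-16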
...
@@ -344,16 +321,14 @@ class GradScaler(object):
            return optimizer.step(*args, **kwargs)

        if "closure" in kwargs:
            raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")

        self._check_scale_growth_tracker("step")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("step() has already been called since the last update().")

        retval = None
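The stage checks in this and the surrounding hunks implement a small per-optimizer state machine: READY, then unscale_() moves it to UNSCALED, step() to STEPPED, and update() back to READY. A sketch of the states as an enum, matching the names used above:

    from enum import Enum

    class OptState(Enum):
        READY = 0      # fresh iteration; unscale_() and step() are both legal
        UNSCALED = 1   # grads unscaled; a second unscale_() must raise
        STEPPED = 2    # step() done; only update() may follow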
...
@@ -369,11 +344,9 @@ class GradScaler(object):
        if optimizer_state["stage"] is OptState.READY:
            self.unscale_(optimizer)

        assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."

        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

        optimizer_state["stage"] = OptState.STEPPED
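Because step() only calls unscale_() itself when the stage is still READY, callers may unscale explicitly first, which is the standard way to clip gradients in their true units (names follow the earlier training-loop sketch):

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)         # grads now hold their unscaled values
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)             # stage is UNSCALED, so no second unscale
    scaler.update()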
...
@@ -418,24 +391,21 @@ class GradScaler(object):
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
                                     self._backoff_factor, self._growth_interval)

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
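A simplified, pure-Python model of the rule torch._amp_update_scale_ applies above (the real kernel keeps the streak in the on-device _growth_tracker): every overflow multiplies the scale by backoff_factor and resets the streak; growth_interval consecutive clean iterations multiply it by growth_factor.

    def update_scale(scale, streak, found_inf,
                     growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
        if found_inf:
            return scale * backoff_factor, 0   # back off and reset the streak
        streak += 1
        if streak == growth_interval:
            return scale * growth_factor, 0    # grow and reset the streak
        return scale, streak

    print(update_scale(2.**16, 0, found_inf=True))   # (32768.0, 0)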
...
@@ -522,11 +492,13 @@ class GradScaler(object):
        If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
        should be called after :meth:`update`.
        """
        return {
            "scale": self.get_scale(),
            "growth_factor": self._growth_factor,
            "backoff_factor": self._backoff_factor,
            "growth_interval": self._growth_interval,
            "_growth_tracker": self._get_growth_tracker()
        } if self._enabled else {}

    def load_state_dict(self, state_dict):
        r"""
...
@@ -572,10 +544,8 @@ class GradScaler(object):
    def _check_inf_per_device(self, optimizer):
        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)

        self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
            self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
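The dummy inverse scale of 1.0 turns _unscale_grads_ into a pure inf/nan check: multiplying gradients by 1 leaves them untouched while the fused kernel still records any non-finite value into found_inf. The two sentinel tensors look like this:

    import torch

    dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32)   # no-op multiplier
    found_inf = torch.full((1,), 0.0, dtype=torch.float32)         # set to 1.0 if inf/nan seen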
...