OpenDAS / ColossalAI · Commits

Commit a4e91bc8 (Unverified)
Authored Apr 12, 2022 by Frank Lee, committed by GitHub on Apr 12, 2022
[bug] fixed grad scaler compatibility with torch 1.8 (#735)
parent 53cb5848
Showing 1 changed file with 15 additions and 2 deletions

colossalai/amp/torch_amp/_grad_scaler.py (+15 −2)
@@ -12,6 +12,7 @@ from colossalai.context import ParallelMode
 import torch.distributed as dist
 from colossalai.core import global_context as gpc
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from packaging import version


 class _MultiDeviceReplicator(object):
@@ -122,6 +123,14 @@ class GradScaler(object):
         else:
             self._enabled = enabled

+        # check version
+        torch_version = version.parse(torch.__version__)
+        assert torch_version.major == 1
+        if torch_version.minor > 8:
+            self._higher_than_torch18 = True
+        else:
+            self._higher_than_torch18 = False
+
         if self._enabled:
             assert growth_factor > 1.0, "The growth factor must be > 1.0."
             assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
@@ -404,8 +413,12 @@ class GradScaler(object):
         for i in range(1, len(found_infs)):
             found_inf_combined += found_infs[i]

-        torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
-                                 self._backoff_factor, self._growth_interval)
+        if self._higher_than_torch18:
+            torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
+                                     self._backoff_factor, self._growth_interval)
+        else:
+            self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined,
+                                                  self._growth_factor, self._backoff_factor,
+                                                  self._growth_interval)

         # To prepare for next iteration, clear the data collected from optimizers this iteration.
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
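The change above probes torch.__version__ once in GradScaler.__init__ and then dispatches at the update site: PyTorch 1.9+ exposes the in-place torch._amp_update_scale_ (scale tensor first), while PyTorch 1.8 and earlier only has the out-of-place torch._amp_update_scale (growth tracker first, new scale returned). The sketch below restates that version-gating idiom in isolation; the helper name update_amp_scale is hypothetical and not part of the commit, and only the two private torch functions and their argument orders are taken from the diff.

# Minimal sketch of the version-gating idiom in this commit, assuming
# a 1.x torch; `update_amp_scale` is a hypothetical helper name.
import torch
from packaging import version

_TORCH_VERSION = version.parse(torch.__version__)
# Probe once at import time rather than parsing the version string on
# every call; the commit itself asserts torch_version.major == 1.
_HIGHER_THAN_TORCH18 = _TORCH_VERSION.major == 1 and _TORCH_VERSION.minor > 8

def update_amp_scale(scale, growth_tracker, found_inf,
                     growth_factor, backoff_factor, growth_interval):
    """Update the AMP loss scale, dispatching on the detected torch version."""
    if _HIGHER_THAN_TORCH18:
        # torch >= 1.9: in-place variant, scale tensor comes first.
        torch._amp_update_scale_(scale, growth_tracker, found_inf,
                                 growth_factor, backoff_factor, growth_interval)
        return scale
    # torch <= 1.8: out-of-place variant, growth tracker comes first,
    # and the updated scale is returned instead of written in place.
    return torch._amp_update_scale(growth_tracker, scale, found_inf,
                                   growth_factor, backoff_factor, growth_interval)

Caching the check in a flag (here _HIGHER_THAN_TORCH18, self._higher_than_torch18 in the commit) keeps the hot update() path free of version parsing, which is why the commit does the detection in __init__ rather than at the call site.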