OpenDAS / fairscale

Commit 10d21b38 (Unverified)
Authored Jan 14, 2022 by Anupam Bhatnagar, committed by GitHub on Jan 14, 2022

small fixes to layerwise gradient scaler (#910)

Parent: 39e7821a
Showing 2 changed files with 3 additions and 1 deletion:

  fairscale/optim/layerwise_gradient_scaler.py   +1 / -1
  tests/optim/test_layerwise_gradient_scaler.py  +2 / -0
fairscale/optim/layerwise_gradient_scaler.py

...
@@ -200,7 +200,7 @@ class LayerwiseGradientScaler:
         layers_with_finite_values = self._get_layers_with_finite_values()
         for item in layers_with_finite_values:
             for param_name, param in item.layer.named_parameters():
-                if hasattr(param, "grad"):
+                if hasattr(param, "grad") and param.grad is not None:
                     logging.debug("%s scaling down %s by %s" % (item.layer_name, param_name, 1.0 / item.scaling_factor))
                     param.grad.mul_(1.0 / item.scaling_factor)
...
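The one-line change above guards against parameters whose gradient was never populated. A minimal sketch of the failure mode, using a plain torch.nn.Linear layer outside of fairscale (not taken from the commit):

import torch

layer = torch.nn.Linear(4, 2)

# Every nn.Parameter exposes a `grad` attribute, but it stays None until a
# backward pass writes to it, so `hasattr(param, "grad")` alone is always True.
for name, param in layer.named_parameters():
    print(name, hasattr(param, "grad"), param.grad)  # -> True, None

# With the old check, the scaling line would call .mul_ on None and raise
# AttributeError for any parameter that never received a gradient.
scaling_factor = 2.0  # hypothetical per-layer factor for illustration
for name, param in layer.named_parameters():
    if hasattr(param, "grad") and param.grad is not None:
        param.grad.mul_(1.0 / scaling_factor)  # safely skipped when grad is None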
tests/optim/test_layerwise_gradient_scaler.py

 import logging
+import os
 from typing import Any, List, Tuple, Union

 import numpy as np
...
@@ -200,6 +201,7 @@ def test_vision_model() -> None:
     # Remove randomness from various sources while testing.
     torch.use_deterministic_algorithms(True)  # type: ignore
     # set environment variable in CircleCI for test to pass: CUBLAS_WORKSPACE_CONFIG = :4096:8
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
     m1 = SimpleConvNet()
     m2 = SimpleConvNet()
...
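For context on the test change: with torch.use_deterministic_algorithms(True), cuBLAS matmul and convolution kernels on CUDA 10.2+ raise a RuntimeError unless CUBLAS_WORKSPACE_CONFIG is set to ":4096:8" or ":16:8", so exporting the variable from the test itself avoids depending on the CircleCI job configuration. A minimal standalone sketch of that behavior (not part of the test file):

import os
import torch

# Set the workspace config before any cuBLAS call; required for deterministic
# behavior on CUDA >= 10.2 (":16:8" also works, with a smaller memory footprint).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)

if torch.cuda.is_available():
    a = torch.randn(8, 8, device="cuda")
    b = torch.randn(8, 8, device="cuda")
    # Without the environment variable this matmul raises a RuntimeError under
    # deterministic mode; with it, results are reproducible run to run.
    print((a @ b).sum().item())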