OpenDAS / apex · Commit 28097c99

Authored Apr 18, 2019 by ptrblck; committed by mcarilli on Apr 18, 2019
Parent: cd2708cc

    initial commit, add CUDA warning to check_params_fp32 (#263)
Showing 1 changed file with 24 additions and 10 deletions.

apex/amp/_initialize.py  (+24, −10)
@@ -75,18 +75,32 @@ def check_models(models):

 def check_params_fp32(models):
     for model in models:
         for name, param in model.named_parameters():
-            if param.is_floating_point() and param.type() != "torch.cuda.FloatTensor":
-                warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
-                    "When using amp.initialize, you do not need to call .half() on your model\n"
-                    "before passing it, no matter what optimization level you choose.".format(
-                    name, param.type()))
+            if param.is_floating_point():
+                if 'Half' in param.type():
+                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you do not need to call .half() on your model\n"
+                        "before passing it, no matter what optimization level you choose.".format(
+                        name, param.type()))
+                elif not param.is_cuda:
+                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you need to provide a model with parameters\n"
+                        "located on a CUDA device before passing it no matter what optimization level\n"
+                        "you chose. Use model.to('cuda') to use the default device.".format(
+                        name, param.type()))
         for name, buf in model.named_buffers():
-            if buf.is_floating_point() and buf.type() != "torch.cuda.FloatTensor":
-                warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
-                    "When using amp.initialize, you do not need to call .half() on your model\n"
-                    "before passing it, no matter what optimization level you choose.".format(
-                    name, buf.type()))
+            if buf.is_floating_point():
+                if 'Half' in buf.type():
+                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you do not need to call .half() on your model\n"
+                        "before passing it, no matter what optimization level you choose.".format(
+                        name, buf.type()))
+                elif not buf.is_cuda:
+                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you need to provide a model with buffers\n"
+                        "located on a CUDA device before passing it no matter what optimization level\n"
+                        "you chose. Use model.to('cuda') to use the default device.".format(
+                        name, buf.type()))


 def check_optimizers(optimizers):
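Below is a minimal, self-contained sketch (not part of the commit) of the parameter branch of the new check, runnable with plain PyTorch. warn_or_err here is a stand-in that just prints, and the Linear model is an illustrative example; apex's real helper also walks named_buffers() exactly as shown in the diff above.

    import torch
    import torch.nn as nn


    def warn_or_err(msg):
        # Stand-in for apex's helper; just surface the message.
        print("Warning:", msg)


    def check_params_fp32(models):
        # Mirrors the new logic: a half-precision param gets the ".half()" message,
        # while a float param that is not on a CUDA device gets the new CUDA message.
        for model in models:
            for name, param in model.named_parameters():
                if param.is_floating_point():
                    if 'Half' in param.type():
                        warn_or_err("Found param {} with type {}, expected "
                                    "torch.cuda.FloatTensor.".format(name, param.type()))
                    elif not param.is_cuda:
                        warn_or_err("Found param {} with type {}, expected "
                                    "torch.cuda.FloatTensor. Use model.to('cuda') "
                                    "to use the default device.".format(name, param.type()))


    if __name__ == "__main__":
        model = nn.Linear(4, 4)           # CPU float32 params -> triggers the new CUDA warning
        check_params_fp32([model])
        if torch.cuda.is_available():
            model = model.to('cuda')      # the fix the new warning suggests
            check_params_fp32([model])    # silent once params are torch.cuda.FloatTensor

With the model's parameters and buffers on a CUDA device and not manually cast with .half(), neither message fires, which is the setup the commit's warning text asks for before passing the model to amp.initialize.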