Commit 1b8303d8 authored by Michael Carilli's avatar Michael Carilli
Browse files

Adding warning for amp.scale_loss

parent 7b245dba
......@@ -75,6 +75,11 @@ def scale_loss(loss,
.. _`Advanced Amp Usage`:
https://nvidia.github.io/apex/advanced.html
"""
if not hasattr(_amp_state, "opt_properties"):
raise RuntimeError("Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized. "
"model, optimizer = amp.initialize(model, optimizer, opt_level=...) must be called "
"before `with amp.scale_loss`.")
if not _amp_state.opt_properties.enabled:
yield loss
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment