Commit 3ff2178c authored by rohithkrn's avatar rohithkrn
Browse files

disble multi tensor apply for O4, O5

parent de3f3fea
......@@ -13,7 +13,7 @@ class AmpOptimizerState(object):
def _master_params_to_model_params(self):
stash = self._amp_stash
if multi_tensor_applier.available and not _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
if len(stash.all_fp16_params) > 0:
multi_tensor_applier(
stash.multi_tensor_scale,
......
......@@ -63,7 +63,7 @@ class LossScaler(object):
self._unskipped = 0
self._has_overflow = False
self._overflow_buf = torch.cuda.IntTensor([0])
if multi_tensor_applier.available:
if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
import amp_C
LossScaler.has_fused_kernel = multi_tensor_applier.available
LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment