"vscode:/vscode.git/clone" did not exist on "2da2f1945ce5567e27d45df43acd4f65108d5c25"
Commit 3ff2178c authored by rohithkrn

disable multi tensor apply for O4, O5

parent de3f3fea
@@ -13,7 +13,7 @@ class AmpOptimizerState(object):
 def _master_params_to_model_params(self):
     stash = self._amp_stash
-    if multi_tensor_applier.available and not _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
+    if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
         if len(stash.all_fp16_params) > 0:
             multi_tensor_applier(
                 stash.multi_tensor_scale,
...
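
The fixed condition removes a double negative: because `not in` binds tighter than `not`, the old `and not _amp_state.opt_properties.opt_level not in {"O4", "O5"}` reduced to `and opt_level in {"O4", "O5"}`, the opposite of the intended guard. Below is a minimal sketch of the intended behavior; the names `copy_master_to_model` and `fused_available` are illustrative stand-ins, not the apex API.

import torch

# Opt levels for which the fused multi-tensor path is disabled by this commit.
_FUSED_DISABLED_LEVELS = {"O4", "O5"}

def copy_master_to_model(master_params, model_params, opt_level, fused_available=False):
    """Copy FP32 master weights back into the (possibly lower-precision) model weights."""
    if fused_available and opt_level not in _FUSED_DISABLED_LEVELS:
        # This is where apex takes the fused path: a single multi_tensor_applier
        # launch (amp_C.multi_tensor_scale) over the whole parameter list.
        raise NotImplementedError("fused path needs the compiled amp_C extension")
    # Unfused fallback, and the only path for O4/O5 after this commit:
    # one tensor copy per parameter pair.
    for model_p, master_p in zip(model_params, master_params):
        model_p.detach().copy_(master_p.detach())

# Example: under O4 the fused kernels are bypassed even if they are importable.
model = [torch.zeros(3, dtype=torch.float16)]
master = [torch.ones(3, dtype=torch.float32)]
copy_master_to_model(master, model, opt_level="O4", fused_available=True)
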
@@ -63,7 +63,7 @@ class LossScaler(object):
         self._unskipped = 0
         self._has_overflow = False
         self._overflow_buf = torch.cuda.IntTensor([0])
-        if multi_tensor_applier.available:
+        if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
             import amp_C
             LossScaler.has_fused_kernel = multi_tensor_applier.available
             LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
...
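
The same opt-level guard is added where LossScaler decides whether to cache the fused amp_C kernels. A minimal sketch of that constructor logic, assuming a simplified signature (the real class takes scaling arguments and stores amp_C.multi_tensor_scale / amp_C.multi_tensor_axpby as the diff shows):

import torch

class LossScalerSketch(object):
    # Mirrors LossScaler.has_fused_kernel from the diff: stays False unless a
    # constructor finds the fused kernels usable for the current opt level.
    has_fused_kernel = False

    def __init__(self, opt_level, fused_available=False):
        self._unskipped = 0
        self._has_overflow = False
        # The real class allocates this buffer on the GPU unconditionally; it is
        # guarded here only so the sketch also runs on CPU-only machines.
        self._overflow_buf = torch.cuda.IntTensor([0]) if torch.cuda.is_available() else None
        if fused_available and opt_level not in {"O4", "O5"}:
            # apex imports amp_C here and caches its multi-tensor kernels on the class.
            LossScalerSketch.has_fused_kernel = True
        else:
            # O4/O5 (and builds without amp_C) stay on the unfused scaling path.
            LossScalerSketch.has_fused_kernel = False

With this change, constructing the scaler under O4 or O5 leaves has_fused_kernel False, so subsequent scaling work goes through the non-fused code path.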