Unverified commit 26b4b71a authored by Tim Moon, committed by GitHub

[PyTorch] Avoid registering FP8 scale update in ops without backward pass (#2063)



Avoid registering FP8 recipe update in ops without backward pass
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ccbc8cf4
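
Context for the change: in Transformer Engine's FP8 flow, the module flagged by FP8GlobalStateManager.is_first_fp8_module() is the one whose backward pass performs the global FP8 scale/amax update. As the commit message suggests, a fuser executing forward-only (e.g. under torch.no_grad()) could previously still claim that slot, registering a recipe update in a backward pass that never runs. Below is a minimal, self-contained sketch of the fixed gating; _FP8State and plan_recipe_update are hypothetical stand-ins, while first_op_requiring_backward, _num_basic_ops, and is_first_fp8_module mirror the names in the diff.

    # Minimal sketch of the gating this commit introduces. NOT Transformer
    # Engine's actual implementation: _FP8State and plan_recipe_update are
    # hypothetical stand-ins for illustration only.

    class _FP8State:
        """Stand-in for FP8GlobalStateManager's first-module bookkeeping."""

        _first_module_claimed = False

        @classmethod
        def is_first_fp8_module(cls) -> bool:
            # The first caller per iteration claims responsibility for the
            # FP8 scale/amax update; every later caller gets False.
            if cls._first_module_claimed:
                return False
            cls._first_module_claimed = True
            return True


    def plan_recipe_update(first_op_requiring_backward: int, num_basic_ops: int) -> bool:
        # The fixed logic: only claim the first-module slot if at least one
        # basic op will actually run a backward pass.
        is_first_module = False
        if first_op_requiring_backward < num_basic_ops:
            is_first_module = _FP8State.is_first_fp8_module()
        return is_first_module


    # Forward-only fuser (e.g. under torch.no_grad): no op needs backward,
    # so it must not claim the update slot...
    assert plan_recipe_update(first_op_requiring_backward=3, num_basic_ops=3) is False
    # ...leaving it for a later fuser that does require backward.
    assert plan_recipe_update(first_op_requiring_backward=0, num_basic_ops=3) is True
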
@@ -176,6 +176,11 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         func_ctx.save_for_backward(*tensors_to_save)
         func_ctx.tensor_objects = tensor_objects
 
+        # Whether to perform recipe update in backward pass
+        is_first_module = False
+        if fuser.first_op_requiring_backward < fuser._num_basic_ops:
+            is_first_module = FP8GlobalStateManager.is_first_fp8_module()
+
         # Other context
         func_ctx.backward_ops = fuser._backward_ops
         func_ctx.basic_ops = fuser._basic_ops
@@ -183,7 +188,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         func_ctx.basic_op_num_params = fuser._basic_op_num_params
         func_ctx.num_extra_inputs = fuser.num_extra_inputs
         func_ctx.num_extra_outputs = len(extra_outputs_flat)
-        func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
+        func_ctx.is_first_module = is_first_module
 
         # Mark output tensors as not deletable in backward
         for tensor in [x] + extra_outputs_flat:
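
For completeness: the saved func_ctx.is_first_module flag is what the fuser's backward pass checks before performing the FP8 scale/amax update (in Transformer Engine this goes through FP8GlobalStateManager's reduce-and-update path; the exact call site is outside this hunk). With the gating above, a fuser whose ops all run forward-only leaves the flag False, so the first module that actually trains keeps ownership of the update.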