Unverified Commit 3d7ff1c6 authored by Tim Moon, Committed by GitHub
Browse files

[PyTorch] Avoid `parameters` function in op backward pass (#1403)



* Avoid `parameters` function in op backward pass
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 7aa81186
...@@ -192,7 +192,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -192,7 +192,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
func_ctx.backward_ops = backward_ops func_ctx.backward_ops = backward_ops
func_ctx.basic_ops = basic_ops func_ctx.basic_ops = basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs func_ctx.basic_op_ctxs = basic_op_ctxs
func_ctx.num_params = num_params func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops]
func_ctx.num_extra_inputs = num_extra_inputs func_ctx.num_extra_inputs = num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
...@@ -258,14 +258,14 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -258,14 +258,14 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Flatten list of parameter gradients # Flatten list of parameter gradients
grad_params_flat = [] grad_params_flat = []
for idx, dparams in enumerate(grad_params): for idx, dparams in enumerate(grad_params):
params = list(basic_ops[idx].parameters()) num_params = func_ctx.basic_op_num_params[idx]
if dparams is None: if dparams is None:
dparams = [None for _ in range(len(params))] dparams = [None for _ in range(num_params)]
else: else:
dparams = list(dparams) dparams = list(dparams)
if len(dparams) != len(params): if len(dparams) != num_params:
raise RuntimeError( raise RuntimeError(
f"Expected op {idx} to generate {len(params)} param grads, " f"Expected op {idx} to generate {num_params} param grads, "
f"but got {len(dparams)}" f"but got {len(dparams)}"
) )
grad_params_flat.extend(dparams) grad_params_flat.extend(dparams)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment