Unverified Commit 3d7ff1c6 authored by Tim Moon, committed by GitHub

[PyTorch] Avoid `parameters` function in op backward pass (#1403)



* Avoid `parameters` function in op backward pass
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 7aa81186
@@ -192,7 +192,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         func_ctx.backward_ops = backward_ops
         func_ctx.basic_ops = basic_ops
         func_ctx.basic_op_ctxs = basic_op_ctxs
-        func_ctx.num_params = num_params
+        func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops]
         func_ctx.num_extra_inputs = num_extra_inputs
         func_ctx.num_extra_outputs = len(extra_outputs_flat)
         func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
@@ -258,14 +258,14 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         # Flatten list of parameter gradients
         grad_params_flat = []
         for idx, dparams in enumerate(grad_params):
-            params = list(basic_ops[idx].parameters())
+            num_params = func_ctx.basic_op_num_params[idx]
             if dparams is None:
-                dparams = [None for _ in range(len(params))]
+                dparams = [None for _ in range(num_params)]
             else:
                 dparams = list(dparams)
-            if len(dparams) != len(params):
+            if len(dparams) != num_params:
                 raise RuntimeError(
-                    f"Expected op {idx} to generate {len(params)} param grads, "
+                    f"Expected op {idx} to generate {num_params} param grads, "
                     f"but got {len(dparams)}"
                 )
             grad_params_flat.extend(dparams)
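
Both hunks apply one pattern: each basic op's parameter count is computed once in forward, while the modules are at hand, and cached on the autograd context as `func_ctx.basic_op_num_params`, so backward can validate gradient counts without re-traversing modules via `parameters()` (which rebuilds a generator over registered parameters on every call). A minimal standalone sketch of that pattern, using hypothetical `_ToyOp` and `_ToyScaleFunction` names that are not part of Transformer Engine:

    import torch

    class _ToyOp(torch.nn.Module):
        """Hypothetical op with a single scalar parameter (sketch only)."""

        def __init__(self):
            super().__init__()
            self.scale = torch.nn.Parameter(torch.ones(()))

    class _ToyScaleFunction(torch.autograd.Function):
        """Hypothetical autograd function showing the cached-count pattern."""

        @staticmethod
        def forward(ctx, x, scale, op):
            # Count parameters once in forward, while the module is in
            # scope, mirroring func_ctx.basic_op_num_params in the diff.
            ctx.num_params = sum(1 for _ in op.parameters())
            ctx.save_for_backward(x, scale)
            return x * scale

        @staticmethod
        def backward(ctx, grad_output):
            x, scale = ctx.saved_tensors
            dparams = [(grad_output * x).sum()]  # one grad per parameter
            # Validate against the cached count; no op.parameters() here.
            if len(dparams) != ctx.num_params:
                raise RuntimeError(
                    f"Expected {ctx.num_params} param grads, "
                    f"but got {len(dparams)}"
                )
            # Gradients for (x, scale, op); the module input gets None.
            return grad_output * scale, dparams[0], None

    op = _ToyOp()
    x = torch.randn(3, requires_grad=True)
    _ToyScaleFunction.apply(x, op.scale, op).sum().backward()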