Unverified Commit 8d4bdbc2 authored by Jan Bielak, committed by GitHub

Optimize `/ops/fuser.py` by moving computation from `forward` to `__init__` (#1870)



* Flatten basic op params during fuser init
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 949abe97070721b1da5117903067608250f5fb61)

* Add caching for is_non_tn_fp8_gemm_supported
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit fd830ae24ffbd2d0727010b1a8a119ca72f61ce5)

* Pass fuser to `_OperationFuserAutogradFunction.forward` and move computation to `__init__` (the overall pattern is sketched after the commit metadata below)
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit fd808991993958b670726896254b82fcb967fa07)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Pass basic_op_kwargs and is_grad_enabled as parameters rather than in fuser
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent d90ced7c
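
The refactor follows a simple pattern: bookkeeping that never changes between calls (flattening the per-op parameter lists, counting ops and extra inputs) is computed once in `OperationFuser.__init__`, and the fuser object itself is handed to the autograd function. A minimal, self-contained sketch of that pattern, using hypothetical `ToyOp`/`ToyFuser` classes rather than the real transformer_engine API:

```python
import torch


class ToyOp(torch.nn.Module):
    """Stand-in for a basic operation holding one parameter."""

    def __init__(self, size: int) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(size))


class ToyFuser:
    """Stand-in for OperationFuser: per-call work is hoisted into __init__."""

    def __init__(self, ops: list[ToyOp]) -> None:
        self._basic_ops = ops
        self._num_basic_ops = len(ops)
        # Flattened once here instead of on every forward call
        self._basic_op_params = [p for op in ops for p in op.parameters()]
        self._num_list_basic_op_params = [
            sum(1 for _ in op.parameters()) for op in ops
        ]

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # The forward path only consumes the precomputed lists
        for param in self._basic_op_params:
            x = x * param
        return x


fuser = ToyFuser([ToyOp(4), ToyOp(4)])
print(fuser(torch.ones(4)))  # tensor([1., 1., 1., 1.], grad_fn=<MulBackward0>)
```

The actual fuser caches the same kind of state under `_basic_op_params`, `_num_list_basic_op_params`, `_num_basic_ops`, and `_num_extra_inputs`, as the diff below shows.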
@@ -61,13 +61,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
     def forward(
         func_ctx: Optional[torch.autograd.function.FunctionCtx],
         input_: torch.Tensor,
-        forward_ops: list[tuple[FusibleOperation, list[int]]],
-        backward_ops: list[tuple[FusibleOperation, list[int]]],
-        basic_ops: list[BasicOperation],
+        fuser: OperationFuser,
         basic_op_kwargs: list[dict[str, Any]],
         is_grad_enabled: bool,
-        num_params: int,
-        num_extra_inputs: int,
         *params_and_extra_inputs: torch.nn.Parameter,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
         """Forward pass
@@ -78,20 +74,12 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             Context for PyTorch autograd function
         input_: torch.Tensor
             Input to first operation in pipeline
-        forward_ops: list of tuple
-            Forward pass operations and the indices of the
-            corresponding basic operations. The order should match
-            basic_ops.
-        backward_ops: list of tuple
-            Backward pass operations and the indices of the
-            corresponding basic operations. The order should be the
-            reverse of basic_ops.
-        basic_ops: list of BasicOperation
-            Basic operations
+        fuser: OperationFuser
+            Container for the pipeline of operations to run
         basic_op_kwargs: list of dict
             Keyword arguments to BasicOperation
-        num_params: int
-            Number of parameter tensors to include in autograd graph.
+        is_grad_enabled: bool
+            Should context be saved for backward
         *params_and_extra_inputs: torch.Tensor
             Other tensor inputs to include in autograd graph. Consists
             of parameter tensors, followed by extra operation inputs.
@@ -106,26 +94,20 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         """

         # Operation autograd contexts
-        basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))]
+        basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]

         # Unflatten list of parameters and extra tensor inputs
-        if len(params_and_extra_inputs) != num_params + num_extra_inputs:
-            raise ValueError(
-                f"Expected {num_params + num_extra_inputs} extra tensor arguments "
-                f"({num_params} parameters, {num_extra_inputs} extra inputs), "
-                f"but got {len(params_and_extra_inputs)}"
-            )
-        _, extra_inputs = _split_tuple(params_and_extra_inputs, num_params)
+        extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
         basic_op_extra_inputs = []
-        for op in basic_ops:
+        for op in fuser._basic_ops:
             xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
             basic_op_extra_inputs.append(xs)

         # Apply forward ops
         x = input_
         requires_grad = is_grad_enabled and x.requires_grad
-        extra_outputs = [None for _ in range(len(basic_ops))]
-        for op, basic_op_idxs in forward_ops:
+        extra_outputs = [None] * fuser._num_basic_ops
+        for op, basic_op_idxs in fuser._forward_ops:

             # Check if backward op is required
             if is_grad_enabled:
@@ -143,9 +125,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):

             # Forward op
             extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
-            prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
+            prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
             next_ops = [
-                basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs
+                fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None
+                for idx in basic_op_idxs
             ]
             x, fused_op_extra_outputs = op.fuser_forward(
                 [basic_op_ctxs[idx] for idx in basic_op_idxs],
@@ -165,7 +148,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         extra_outputs_flat = []
         for idx, ys in enumerate(extra_outputs):
             ys = list(ys)
-            num_extra_outputs = basic_ops[idx].num_extra_outputs
+            num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
             if len(ys) != num_extra_outputs:
                 raise RuntimeError(
                     f"Expected op {idx} to generate "
@@ -189,11 +172,11 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             func_ctx.save_for_backward(*to_save)

             # Other context
-            func_ctx.backward_ops = backward_ops
-            func_ctx.basic_ops = basic_ops
+            func_ctx.backward_ops = fuser._backward_ops
+            func_ctx.basic_ops = fuser._basic_ops
             func_ctx.basic_op_ctxs = basic_op_ctxs
-            func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops]
-            func_ctx.num_extra_inputs = num_extra_inputs
+            func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
+            func_ctx.num_extra_inputs = fuser._num_extra_inputs
             func_ctx.num_extra_outputs = len(extra_outputs_flat)
             func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()

@@ -293,13 +276,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):

         return (
             dx,  # input_
-            None,  # forward_ops
-            None,  # backward_ops
-            None,  # basic_ops
+            None,  # fuser
             None,  # basic_op_kwargs
             None,  # is_grad_enabled
-            None,  # num_params
-            None,  # num_extra_inputs
             *grad_params_flat,
             *grad_extra_inputs_flat,
         )
@@ -346,6 +325,10 @@ class OperationFuser:
         if fuse_ops:
             self.fuse_ops()

+        # Flatten list of parameters
+        self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
+        self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]
+
     @classmethod
     def _fuse_forward_ops(
         cls,
@@ -378,6 +361,11 @@ class OperationFuser:
         *extra_inputs: torch.Tensor,
         basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
+        # Verify extra input count
+        if len(extra_inputs) != self._num_extra_inputs:
+            raise ValueError(
+                f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
+            )

         # Initialization before forward pass
         for op in self._basic_ops:
@@ -385,10 +373,7 @@ class OperationFuser:

         # Canonicalize op kwargs
         if basic_op_kwargs is None:
-            basic_op_kwargs = [{} for _ in range(len(self._basic_ops))]
-
-        # Flatten list of parameters
-        params = [param for op in self._basic_ops for param in op.parameters()]
+            basic_op_kwargs = [{}] * self._num_basic_ops

         # Fuser forward pass
         is_grad_enabled = torch.is_grad_enabled()
@@ -400,14 +385,10 @@ class OperationFuser:
             args = [None]
         args += (
             input,
-            self._forward_ops,
-            self._backward_ops,
-            self._basic_ops,
+            self,
             basic_op_kwargs,
             is_grad_enabled,
-            len(params),
-            self._num_extra_inputs,
-            *params,
+            *self._basic_op_params,
             *extra_inputs,
         )
         return forward_func(*args)
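
One detail the hunk above preserves: even though the parameter list is now precomputed, it is still unpacked into the argument tuple rather than read from the fuser inside the autograd function. PyTorch only builds autograd edges to tensors that are passed positionally into an `autograd.Function`, so the flattened parameters must stay explicit arguments. A hedged toy illustration of that mechanism (hypothetical names, not the transformer_engine code):

```python
import torch


class ToySum(torch.autograd.Function):
    """Adds every tensor passed positionally; each one gets a gradient."""

    @staticmethod
    def forward(ctx, x, *params):
        ctx.num_params = len(params)
        out = x
        for p in params:
            out = out + p
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # One gradient per positional argument of forward: x, then each param
        return (grad_out,) + (grad_out,) * ctx.num_params


x = torch.randn(3, requires_grad=True)
w1 = torch.nn.Parameter(torch.randn(3))
w2 = torch.nn.Parameter(torch.randn(3))
ToySum.apply(x, w1, w2).sum().backward()
print(w1.grad is not None, w2.grad is not None)  # True True
```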
@@ -448,6 +448,7 @@ def is_bf16_compatible() -> None:
     return torch.cuda.get_device_capability()[0] >= 8


+@functools.lru_cache(maxsize=None)
 def is_non_tn_fp8_gemm_supported() -> bool:
     """Checks whether the device supports
     non-TN layouts for FP8 GEMMs.
...
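
The last hunk caches `is_non_tn_fp8_gemm_supported` with `functools.lru_cache` so the device query runs only once per process. A hedged sketch of the same memoization idea (the predicate body below is illustrative, not the real implementation):

```python
import functools

import torch


@functools.lru_cache(maxsize=None)
def device_capability_at_least(major: int) -> bool:
    """Illustrative device check; the CUDA query runs once per argument value."""
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= major


device_capability_at_least(9)                   # first call: queries the device
device_capability_at_least(9)                   # second call: served from the cache
print(device_capability_at_least.cache_info())  # CacheInfo(hits=1, misses=1, ...)
```

Since the function in the diff takes no arguments, the cache simply holds one boolean for the lifetime of the process.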