"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "7f5e4cb96e99a31b5498dbd7c03def85467338ca"
Unverified commit 8d4bdbc2, authored by Jan Bielak, committed by GitHub

Optimize `/ops/fuser.py` by moving computation from `forward` to `__init__` (#1870)



* Flatten basic op params during fuser init
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 949abe97070721b1da5117903067608250f5fb61)
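
A minimal sketch of the pattern behind this change (illustrative only; `_ToyFuser` is a hypothetical stand-in for `OperationFuser`): the flat parameter list and the per-op parameter counts are built once at construction, so they no longer need to be recomputed on every forward call.

```python
import torch.nn as nn

class _ToyFuser:
    """Hypothetical stand-in for OperationFuser, for illustration only."""

    def __init__(self, basic_ops):
        self._basic_ops = list(basic_ops)
        # Built once at construction instead of on every forward call.
        self._basic_op_params = [p for op in self._basic_ops for p in op.parameters()]
        self._num_list_basic_op_params = [
            sum(1 for _ in op.parameters()) for op in self._basic_ops
        ]

fuser = _ToyFuser([nn.Linear(4, 4), nn.LayerNorm(4)])
assert len(fuser._basic_op_params) == 4  # 2 weights + 2 biases
assert fuser._num_list_basic_op_params == [2, 2]
```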

* Add caching for is_non_tn_fp8_gemm_supported
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit fd830ae24ffbd2d0727010b1a8a119ca72f61ce5)
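
The caching here is the standard `functools.lru_cache` memoization pattern (visible in the last hunk below, which decorates `is_non_tn_fp8_gemm_supported`). A hedged sketch of the same idea applied to a made-up capability probe:

```python
import functools

import torch

@functools.lru_cache(maxsize=None)
def _device_is_sm80_or_newer() -> bool:
    """Hypothetical helper: the CUDA capability query runs only on the first call."""
    return torch.cuda.get_device_capability()[0] >= 8
```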

* Pass fuser to _OperationFuserAutogradFunction.forward and move computation to __init__
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit fd808991993958b670726896254b82fcb967fa07)
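
Roughly the shape of this change (a simplified sketch, not the real `_OperationFuserAutogradFunction` signature): the autograd function receives one container object plus the tensors, instead of several pre-unpacked lists and counts, and keeps the non-tensor bookkeeping on the context.

```python
import torch

class _ToyFuserFunction(torch.autograd.Function):
    """Hypothetical simplified analogue, for illustration only."""

    @staticmethod
    def forward(ctx, x, fuser, *params):
        ctx.fuser = fuser  # non-tensor container stashed on the context
        ctx.num_params = len(params)
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient slot per forward argument: x, fuser, then each param.
        return grad_output * 2, None, *([None] * ctx.num_params)
```

Calling it looks like `_ToyFuserFunction.apply(x, fuser, *params)`, mirroring how `OperationFuser.__call__` now passes `self` and the precomputed `self._basic_op_params` (see the diff below).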

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Pass basic_op_kwargs and is_grad_enabled as parameters rather than in fuser
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent d90ced7c
@@ -61,13 +61,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
     def forward(
         func_ctx: Optional[torch.autograd.function.FunctionCtx],
         input_: torch.Tensor,
-        forward_ops: list[tuple[FusibleOperation, list[int]]],
-        backward_ops: list[tuple[FusibleOperation, list[int]]],
-        basic_ops: list[BasicOperation],
+        fuser: OperationFuser,
         basic_op_kwargs: list[dict[str, Any]],
         is_grad_enabled: bool,
-        num_params: int,
-        num_extra_inputs: int,
         *params_and_extra_inputs: torch.nn.Parameter,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
         """Forward pass
@@ -78,20 +74,12 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             Context for PyTorch autograd function
         input_: torch.Tensor
             Input to first operation in pipeline
-        forward_ops: list of tuple
-            Forward pass operations and the indices of the
-            corresponding basic operations. The order should match
-            basic_ops.
-        backward_ops: list of tuple
-            Backward pass operations and the indices of the
-            corresponding basic operations. The order should be the
-            reverse of basic_ops.
-        basic_ops: list of BasicOperation
-            Basic operations
+        fuser: OperationFuser
+            Container for the pipeline of operations to run
         basic_op_kwargs: list of dict
             Keyword arguments to BasicOperation
-        num_params: int
-            Number of parameter tensors to include in autograd graph.
         is_grad_enabled: bool
             Should context be saved for backward
         *params_and_extra_inputs: torch.Tensor
             Other tensor inputs to include in autograd graph. Consists
             of parameter tensors, followed by extra operation inputs.
@@ -106,26 +94,20 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         """
 
         # Operation autograd contexts
-        basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))]
+        basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]
 
         # Unflatten list of parameters and extra tensor inputs
-        if len(params_and_extra_inputs) != num_params + num_extra_inputs:
-            raise ValueError(
-                f"Expected {num_params + num_extra_inputs} extra tensor arguments "
-                f"({num_params} parameters, {num_extra_inputs} extra inputs), "
-                f"but got {len(params_and_extra_inputs)}"
-            )
-        _, extra_inputs = _split_tuple(params_and_extra_inputs, num_params)
+        extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
         basic_op_extra_inputs = []
-        for op in basic_ops:
+        for op in fuser._basic_ops:
             xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
             basic_op_extra_inputs.append(xs)
 
         # Apply forward ops
         x = input_
         requires_grad = is_grad_enabled and x.requires_grad
-        extra_outputs = [None for _ in range(len(basic_ops))]
-        for op, basic_op_idxs in forward_ops:
+        extra_outputs = [None] * fuser._num_basic_ops
+        for op, basic_op_idxs in fuser._forward_ops:
 
             # Check if backward op is required
             if is_grad_enabled:
@@ -143,9 +125,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
 
             # Forward op
             extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
-            prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
+            prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
             next_ops = [
-                basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs
+                fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None
+                for idx in basic_op_idxs
             ]
             x, fused_op_extra_outputs = op.fuser_forward(
                 [basic_op_ctxs[idx] for idx in basic_op_idxs],
@@ -165,7 +148,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         extra_outputs_flat = []
         for idx, ys in enumerate(extra_outputs):
             ys = list(ys)
-            num_extra_outputs = basic_ops[idx].num_extra_outputs
+            num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
             if len(ys) != num_extra_outputs:
                 raise RuntimeError(
                     f"Expected op {idx} to generate "
@@ -189,11 +172,11 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             func_ctx.save_for_backward(*to_save)
 
             # Other context
-            func_ctx.backward_ops = backward_ops
-            func_ctx.basic_ops = basic_ops
+            func_ctx.backward_ops = fuser._backward_ops
+            func_ctx.basic_ops = fuser._basic_ops
             func_ctx.basic_op_ctxs = basic_op_ctxs
-            func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops]
-            func_ctx.num_extra_inputs = num_extra_inputs
+            func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
+            func_ctx.num_extra_inputs = fuser._num_extra_inputs
             func_ctx.num_extra_outputs = len(extra_outputs_flat)
             func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
@@ -293,13 +276,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
 
         return (
             dx,  # input_
-            None,  # forward_ops
-            None,  # backward_ops
-            None,  # basic_ops
+            None,  # fuser
             None,  # basic_op_kwargs
             None,  # is_grad_enabled
-            None,  # num_params
-            None,  # num_extra_inputs
             *grad_params_flat,
             *grad_extra_inputs_flat,
         )
@@ -346,6 +325,10 @@ class OperationFuser:
         if fuse_ops:
             self.fuse_ops()
 
+        # Flatten list of parameters
+        self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
+        self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]
+
     @classmethod
     def _fuse_forward_ops(
         cls,
@@ -378,6 +361,11 @@
         *extra_inputs: torch.Tensor,
         basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
+        # Verify extra input count
+        if len(extra_inputs) != self._num_extra_inputs:
+            raise ValueError(
+                f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
+            )
 
         # Initialization before forward pass
         for op in self._basic_ops:
@@ -385,10 +373,7 @@
 
         # Canonicalize op kwargs
         if basic_op_kwargs is None:
-            basic_op_kwargs = [{} for _ in range(len(self._basic_ops))]
-
-        # Flatten list of parameters
-        params = [param for op in self._basic_ops for param in op.parameters()]
+            basic_op_kwargs = [{}] * self._num_basic_ops
 
         # Fuser forward pass
         is_grad_enabled = torch.is_grad_enabled()
@@ -400,14 +385,10 @@
             args = [None]
         args += (
             input,
-            self._forward_ops,
-            self._backward_ops,
-            self._basic_ops,
+            self,
             basic_op_kwargs,
             is_grad_enabled,
-            len(params),
-            self._num_extra_inputs,
-            *params,
+            *self._basic_op_params,
             *extra_inputs,
         )
         return forward_func(*args)
@@ -448,6 +448,7 @@ def is_bf16_compatible() -> None:
     return torch.cuda.get_device_capability()[0] >= 8
 
 
+@functools.lru_cache(maxsize=None)
 def is_non_tn_fp8_gemm_supported() -> bool:
     """Checks whether the device supports
     non-TN layouts for FP8 GEMMs.