Unverified commit a3df1d73, authored by Tim Moon and committed by GitHub

[PyTorch] Prototype for operation-based API (#707)
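A rough usage sketch of the API this PR introduces. Module and class names follow the commit log below (`transformer_engine.pytorch.ops` exporting `Sequential`, `Linear`, and `Bias`); treat the exact signatures as assumptions, not the final interface:

```python
# Hypothetical usage sketch of the operation-based API in this PR.
# Assumes transformer_engine.pytorch.ops exports Sequential, Linear,
# and Bias as described in the commit log; exact signatures may differ.
import torch
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch import fp8_autocast

# Compose basic ops; adjacent ops (e.g. Linear followed by Bias) are
# candidates for fusion by the operation fuser.
model = te_ops.Sequential(
    te_ops.Linear(1024, 1024, bias=False),
    te_ops.Bias(1024),
)

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
with fp8_autocast(enabled=True):  # FP8 compute with the default recipe
    y = model(x)
y.sum().backward()
```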



* Add basic infrastructure for Sequential module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add FP8 support in linear op

Runs, but needs validation. Runtime errors occur with non-FP8 params and FP8 compute, or with FP8 params and non-FP8 compute.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add reshape op and unit test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add unfused linear op

Test does not pass with FP8.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug unfused linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add test for linear+bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add separate abstract classes for unfused and fused ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Consolidate unfused ops in submodule
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add linear-bias fused op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use fused cast-transpose in linear ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable GEMM+bias fusion with FP32 activations

Not supported by cuBLAS.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add parallel unit test for unfused linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor parallel tests to reduce job launches
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add all-reduce, all-gather, and reduce-scatter ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused file
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug multi-GPU FP8 test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for FP8 scale updates

Still need to implement amax reductions.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add license boilerplate
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fuse GEMM+bias in row TP

Add documentation for unfused ops.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename pipeline to fuser

Expand documentation.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak documentation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Preserve cached FP8 transpose between ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add option for fused wgrad accumulation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Directly output FP8 from linear if needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix cuDNN front-end commit
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use updated FP8 tensor API for transpose caching
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use updated API for FP8 scale updates
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tests for non-default FP8 recipes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename UnfusedOperation to BasicOperation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add unit test to check amax reduction with fusible op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Operator autograd state no longer needs to be initialized
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initial functional implementation of linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug fused linear+bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove autograd context from functional linear impl
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use functional linear impl in fused linear+bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename subdirectory from "fuser" to "ops"

Avoid confusion with kernel fusers and graph compilers.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update with Float8Tensor changes in #820
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unnecessary CPU overheads
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Correctly pass FP8 metadata from next op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter errors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add convenience functions to manipulate Sequential class
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update name of PyTorch extensions module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Clear saved tensor data in linear op after bprop
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix Pylint error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update name of PyTorch extensions module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix test name in QA script
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update name of PyTorch extensions module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Run distributed tests even when only 1 GPU is available
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Only run distributed tests with 2 GPUs if there are >=2 GPUs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

* Review suggestions from @sudhakarsingh27 and @ksivaman

Fix spelling of "fusible". Avoid "input" name in internal APIs.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update transformer_engine/pytorch/ops/__init__.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 05977f44
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Base classes for fusible operations."""
from __future__ import annotations
import abc
from collections.abc import Iterable
import dataclasses
from typing import Any, Optional
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
get_default_fp8_recipe,
)
from ._common import canonicalize_device, is_float8_tensor
@dataclasses.dataclass
class OperationContext:
"""State needed to apply an operation
Saves state from forward pass for use in backward pass.
"""
# Tensors that have been saved from forward function
# Note: Available in the backward function, matching tensors from
# to_save.
saved_tensors: Optional[tuple[Optional[torch.Tensor], ...]] = None
# Tensors to save for backward function
# Note: Expected to be set in the forward function, either
# directly or with save_for_backward.
to_save: Optional[tuple[Optional[torch.Tensor], ...]] = None
# Corresponding range in pipeline's list of saved tensors
_saved_tensors_range: Optional[tuple[int, int]] = None
# Whether backward pass is required
_requires_grad: bool = False
def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None:
"""Register tensors to be saved for the backward function
Expected to be called in the forward function.
"""
self.to_save = tensors
class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
"""Tensor operation supported by the operation fuser"""
@property
@abc.abstractmethod
def is_fused_op(self) -> bool:
"""Whether this op is the fusion of one or more basic ops"""
def pre_forward(self) -> None:
"""Preprocessing before forward pass"""
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]],
) -> torch.Tensor:
"""Forward pass
This op is either a basic op or the fusion of basic ops, so
several of this function's arguments are lists of arguments to
forward functions of corresponding basic ops.
Called by `OperationFuser`.
Parameters
----------
basic_op_ctxs: list of OperationContext
Contexts for corresponding basic operations
input_: torch.Tensor
Input tensor
basic_op_prev_ops: list of BasicOperation
Basic operations that precede each of the corresponding
basic operations (or `None` if corresponding basic op is
first)
basic_op_next_ops: list of BasicOperation
Basic operations that follow each of the corresponding
basic operations (or `None` if corresponding basic op is
last)
basic_op_kwargs: list of dict
Keyword arguments to forward functions of corresponding
basic operations
Returns
-------
torch.Tensor: Output tensor.
"""
raise NotImplementedError(
f"Forward pass is not implemented for operation ({self.__class__.__name__})"
)
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, Iterable[Iterable[Optional[torch.Tensor]]]]:
"""Backward pass
This op is either a basic op or the fusion of basic ops, so
several of this function's arguments are lists of arguments to
backward functions of corresponding basic ops.
Called by `OperationFuser`.
Parameters
----------
basic_op_ctxs: list of OperationContext
Contexts for corresponding basic operations.
grad_output: torch.Tensor
Loss gradient w.r.t. operation output.
Returns
-------
torch.Tensor:
Loss gradient w.r.t. operation input
Iterable of iterable of torch.Tensor:
Loss gradients w.r.t. parameters for corresponding basic
operations
"""
raise NotImplementedError(
f"Backward pass is not implemented for operation ({self.__class__.__name__})"
)
class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
"""Single tensor operation supported by the operation fuser
This class holds parameters and state, even if the actual forward
and backward passes are performed by a fused operation.
"""
def __init__(self) -> None:
super().__init__()
# FP8 metadata objects
self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
@property
def is_fused_op(self) -> bool:
return False
# pylint: disable=no-self-use
def num_fp8_scales(
self,
mode: str, # pylint: disable=unused-argument
) -> int:
"""Number of FP8 scaling factors
Parameters
----------
mode: {"input", "param", "grad_output"}
Type of FP8 scaling factor
"""
return 0
def _make_fp8_metas(self) -> dict[str, Optional[dict[str, Any]]]:
"""Construct FP8 metadata"""
# Shared objects for FP8 metadata
dtype = torch.float32
device = canonicalize_device(None)
recipe = get_default_fp8_recipe()
def _make_meta(
num_scales: int,
is_forward: bool,
) -> Optional[dict[str, Any]]:
"""Construct FP8 metadata for one tensor type"""
if num_scales == 0:
return None
key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
meta = tex.FP8TensorMeta()
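# Scales and scale inverses start at one; the amax history starts at
# zero, with one row per step in the recipe's history window.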
meta.scale = torch.ones(num_scales, dtype=dtype, device=device)
meta.scale_inv = torch.ones(num_scales, dtype=dtype, device=device)
meta.amax_history = torch.zeros(
(recipe.amax_history_len, num_scales),
dtype=dtype,
device=device,
)
return {
key: meta,
"recipe": recipe,
"fp8_group": None,
}
# Construct FP8 metadata for all tensor types
return dict(
input=_make_meta(self.num_fp8_scales("input"), True),
param=_make_meta(self.num_fp8_scales("param"), True),
grad_output=_make_meta(self.num_fp8_scales("grad_output"), False),
)
@classmethod
def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None:
if fp8_meta is None:
return
# Update FP8 recipe and communication group
recipe = FP8GlobalStateManager.get_fp8_recipe()
fp8_meta["recipe"] = recipe
fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
# Adjust amax history length if needed
amax_history_len = recipe.amax_history_len
for is_forward in (True, False):
key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
if key not in fp8_meta:
continue
meta = fp8_meta[key]
curr_len = meta.amax_history.size(0)
if curr_len == amax_history_len:
continue
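# Resize the history: keep the first amax_history_len rows, or
# zero-pad new rows at the end (pad=(0, 0, 0, n) pads dim 0 by n).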
with torch.no_grad():
if curr_len > amax_history_len:
meta.amax_history = meta.amax_history[:amax_history_len].clone()
else:
meta.amax_history = torch.nn.functional.pad(
meta.amax_history,
pad=(0, 0, 0, amax_history_len - curr_len),
)
def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]:
"""FP8 metadata
Parameters
----------
mode: {"input", "param", "grad_output"}
Type of FP8 scaling factor
"""
if self._fp8_metas is None:
self._fp8_metas = self._make_fp8_metas()
return self._fp8_metas[mode]
def pre_forward(self) -> None:
"""Preprocessing before forward pass"""
# Initialize FP8 metadata if needed
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
if fp8_enabled:
# Construct FP8 metadata if needed
if self._fp8_metas is None:
self._fp8_metas = self._make_fp8_metas()
# Make sure FP8 metadata matches FP8 autocast context
for fp8_meta in self._fp8_metas.values():
self._maybe_update_fp8_meta(fp8_meta)
# Register FP8 metadata for amax and scale update
if not FP8GlobalStateManager.fp8_graph_capturing():
if self.num_fp8_scales("input"):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self.get_fp8_meta("input"),
)
if self.num_fp8_scales("param"):
fp8_params = list(filter(is_float8_tensor, self.parameters()))
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self.get_fp8_meta("param"),
fp8_weights=(fp8_params if fp8_params else None),
)
if self.num_fp8_scales("grad_output"):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self.get_fp8_meta("grad_output"),
)
@abc.abstractmethod
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
**kwargs: Any,
) -> torch.Tensor:
"""Forward pass
Parameters
----------
ctx: OperationContext
Context to coordinate between forward and backward passes
input_: torch.Tensor
Input tensor
Returns
-------
torch.Tensor:
Output tensor
"""
@abc.abstractmethod
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
"""Backward pass
Parameters
----------
ctx: OperationContext
Context to coordinate between forward and backward passes
grad_output: torch.Tensor
Loss gradient w.r.t. operation output
Returns
-------
torch.Tensor:
Loss gradient w.r.t. operation input
Iterable of torch.Tensor:
Loss gradients w.r.t. parameters
"""
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]],
) -> torch.Tensor:
return self.op_forward(
basic_op_ctxs[0],
input_,
basic_op_prev_ops[0],
basic_op_next_ops[0],
**basic_op_kwargs[0],
)
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, Iterable[Iterable[Optional[torch.Tensor]]]]:
grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output)
return grad_input, [grad_params]
def forward(
self,
input: torch.Tensor, # pylint: disable=redefined-builtin
**kwargs: Any,
) -> torch.Tensor:
"""Apply operation"""
from .fuser import OperationFuser
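# Wrap this op in a single-op fuser with fusion disabled, so it
# runs through the standard fuser forward/backward path.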
return OperationFuser([self], fuse_ops=False)(input, [kwargs])
class FusedOperation(FusibleOperation):
"""Compound tensor operation supported by the operation fuser
If the forward or backward passes are defined, they must be
functionally equivalent to the forward/backward passes of the
corresponding basic ops. This class should hold no parameters or
other state, but should access them from the basic ops.
Parameters
----------
basic_ops: iterable of FusibleOperation
Basic ops that are interchangeable with this op
"""
def __init__(
self,
basic_ops: Iterable[FusibleOperation],
) -> None:
super().__init__()
# Basic operations that comprise this fused operation
self.basic_ops: torch.nn.ModuleList = torch.nn.ModuleList(basic_ops)
if len(self.basic_ops) == 0:
raise ValueError(
"Attempted to construct a fused operation "
"without specifying its corresponding basic operations"
)
@property
def is_fused_op(self) -> bool:
return True
def pre_forward(self) -> None:
"""Preprocessing before forward pass"""
for op in self.basic_ops:
op.pre_forward()
def forward(
self,
input: torch.Tensor, # pylint: disable=redefined-builtin
basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
) -> torch.Tensor:
"""Apply operation"""
if basic_op_kwargs is None:
basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
from .fuser import OperationFuser
return OperationFuser([self], fuse_ops=False)(input, basic_op_kwargs)
...
@@ -32,7 +32,6 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
         if t is not None:
             if isinstance(t, Float8Tensor):
                 t._data.data = torch.Tensor()
-                del t
             else:
                 t.data = torch.Tensor()
             del t
...