Unverified Commit 21b780cc authored by Jan Bielak, committed by GitHub

Enable use of internal tensors in Sequential (#1900)



* Replace `is_float8_tensor` with `is_quantized_tensor`

Replace the free function `is_float8_tensor` with `is_quantized_tensor` in `_common.py` and use it throughout the `ops` codebase to check whether a tensor is a (possibly internal) quantized tensor; a sketch of the helper follows below.
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
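A minimal sketch of what the new helper in `_common.py` plausibly looks like; only its call sites appear in this diff, so the exact body is an assumption:

```python
# Hypothetical sketch of is_quantized_tensor in _common.py. Only the
# call sites are shown in this commit, so the body is an assumption.
from typing import Any

from transformer_engine.pytorch.tensor.quantized_tensor import (
    QuantizedTensor,
    QuantizedTensorBase,
)


def is_quantized_tensor(tensor: Any) -> bool:
    """Check whether a tensor is a (possibly internal) quantized tensor."""
    # Internal tensors (e.g. Float8TensorBase) are QuantizedTensorBase
    # objects that are not torch.Tensor subclasses, so checking only
    # for QuantizedTensor would miss them.
    return isinstance(tensor, (QuantizedTensor, QuantizedTensorBase))
```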

* Pass next and previous op quantizers directly to op_forward and fuser_forward

Change the interface of `fuser_forward` and `op_forward` so that they no longer take the preceding and following ops, and instead take the following op's input quantizer and the preceding op's grad input quantizer directly; see the sketch below.
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
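An illustrative sketch of the new interface, matching the `op_forward` signature shown in `op.py` below; the toy op and its import paths are assumptions for demonstration only:

```python
# Toy op sketch using the new op_forward interface from this diff.
# Import paths are assumed; ScaleByTwo itself is hypothetical and its
# backward pass is omitted for brevity.
from typing import Any, Optional

import torch

from transformer_engine.pytorch.ops.op import BasicOperation, OperationContext
from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer


class ScaleByTwo(BasicOperation):
    """Multiply the input by two (demonstration only)."""

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        is_first_op: bool,
        **kwargs: Any,
    ) -> torch.Tensor:
        # The neighboring ops' quantizers now arrive directly instead of
        # the ops themselves: a real op could quantize its output with
        # next_op_input_quantizer and its input grad with
        # prev_op_grad_input_quantizer.
        return 2 * input_
```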

* Remove redundant `detach` in `BasicLinear`

Remove the use of `detach` in `BasicLinear` for improved performance (enabled by no longer passing `prev_op` to the backward pass).
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Handle saving internal tensors

Handle saving internal tensors in `_OperationFuserAutogradFunction` using `prepare_for_saving` and `restore_from_saved`; see the sketch below.
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
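The save/restore pattern, extracted into a self-contained toy autograd function; it calls `prepare_for_saving` and `restore_from_saved` exactly as the fuser hunks below do, but the demo class itself is hypothetical:

```python
# Toy autograd function using the same save/restore pattern as the
# fuser hunks below; the class itself is hypothetical.
import torch

from transformer_engine.pytorch.tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)


class _SaveRestoreDemo(torch.autograd.Function):
    """Save (possibly internal) quantized tensors across forward/backward."""

    @staticmethod
    def forward(ctx, x, weight):
        # prepare_for_saving splits quantized tensors into plain
        # torch.Tensors (which save_for_backward accepts) plus
        # Python-side tensor objects kept directly on the context.
        tensors_to_save, tensor_objects = prepare_for_saving(x, weight)
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # restore_from_saved reattaches the raw data to the saved
        # objects, recovering the original (quantized) tensors.
        x, weight = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
        del x, weight  # unused in this toy example
        return grad_output, None
```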

* Use internal tensors

Enable use of internal tensors in the `BasicLinear` quantizers and fix issues caused by internal tensors lacking methods that regular tensors have (e.g. the `view` method added to `Float8TensorBase` below); a usage sketch follows below.
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
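Taken together, this is what the change enables end to end; a hedged usage sketch (sizes, device, and recipe are arbitrary, and `te_ops.Sequential` is the fuser-based Sequential this commit targets):

```python
# Hedged usage sketch: arbitrary sizes, default FP8 recipe assumed.
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

model = te_ops.Sequential(
    te_ops.Linear(1024, 1024),
    te_ops.Linear(1024, 1024),
)
x = torch.randn(32, 1024, device="cuda", requires_grad=True)
with te.fp8_autocast():
    # With internal tensors enabled in the BasicLinear quantizers, the
    # tensors passed between the fused ops (and those saved for
    # backward) can stay in their quantized internal form.
    y = model(x)
y.sum().backward()
```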

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 447de6da
......@@ -20,11 +20,11 @@ from ...module.base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
- from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
+ from ...tensor.quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..basic import BasicLinear, Bias, ReduceScatter
- from .._common import maybe_dequantize
+ from .._common import maybe_dequantize, is_quantized_tensor
from ..op import FusedOperation, FusibleOperation, OperationContext
......@@ -280,7 +280,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Cast grad output tensor dtype if needed
dy_local = grad_output
if with_quantized_compute:
- if not isinstance(dy_local, QuantizedTensorBase):
+ if not is_quantized_tensor(dy_local):
with_columnwise = weight_requires_grad
if (
with_columnwise
......@@ -301,7 +301,7 @@ class UserbuffersBackwardLinear(FusedOperation):
raise ValueError("Weight tensor is required to compute input grad")
w = weight
if with_quantized_compute:
- if not isinstance(w, QuantizedTensorBase):
+ if not is_quantized_tensor(w):
weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w)
else:
......@@ -314,7 +314,7 @@ class UserbuffersBackwardLinear(FusedOperation):
raise ValueError("Input tensor is required to compute weight grad")
x_local = input
if with_quantized_compute:
- if not isinstance(x_local, QuantizedTensorBase):
+ if not is_quantized_tensor(x_local):
input_quantizer.set_usage(columnwise=True)
x_local = input_quantizer(x_local)
else:
......@@ -425,7 +425,7 @@ class UserbuffersBackwardLinear(FusedOperation):
raise RuntimeError(
"wgrad GEMM requires grad output tensor, which has not been initialized"
)
- if isinstance(dy, QuantizedTensorBase):
+ if is_quantized_tensor(dy):
dy.update_usage(rowwise_usage=False, columnwise_usage=True)
# Initialize input tensor
......@@ -435,7 +435,7 @@ class UserbuffersBackwardLinear(FusedOperation):
raise RuntimeError(
"wgrad GEMM requires input tensor, which has not been initialized"
)
- if isinstance(x, QuantizedTensorBase):
+ if is_quantized_tensor(x):
x.update_usage(rowwise_usage=False, columnwise_usage=True)
# Check grad weight tensor
......
......@@ -20,14 +20,13 @@ from ...module.base import (
get_workspace,
_2X_ACC_FPROP,
)
- from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
+ from ...tensor.quantized_tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
- from .._common import maybe_dequantize
+ from .._common import maybe_dequantize, is_quantized_tensor
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import (
- BasicOperation,
FusedOperation,
FusibleOperation,
OperationContext,
......@@ -207,7 +206,7 @@ class UserbuffersForwardLinear(FusedOperation):
x = None
if with_ub_all_gather:
if input_quantizer is not None:
- if not isinstance(x_local, QuantizedTensorBase):
+ if not is_quantized_tensor(x_local):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
......@@ -223,7 +222,7 @@ class UserbuffersForwardLinear(FusedOperation):
)
else:
if with_quantized_compute:
- if not isinstance(x_local, QuantizedTensorBase):
+ if not is_quantized_tensor(x_local):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
x_local = input_quantizer(x_local)
else:
......@@ -234,7 +233,7 @@ class UserbuffersForwardLinear(FusedOperation):
w = weight
if not with_quantized_compute:
w = maybe_dequantize(w, dtype)
- elif with_quantized_compute and not isinstance(w, QuantizedTensorBase):
+ elif with_quantized_compute and not is_quantized_tensor(w):
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w)
......@@ -266,18 +265,14 @@ class UserbuffersForwardLinear(FusedOperation):
# Prepare weight tensor for backward pass
if input_requires_grad:
- if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensorBase):
+ if w is not weight and with_quantized_compute and is_quantized_tensor(w):
w.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
w = None
# Prepare input tensor for backward pass
if weight_requires_grad:
- if x_local is input:
-     # PyTorch autograd produces esoteric errors if we
-     # cache input tensor directly.
-     x_local = x_local.detach()
- if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
+ if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
......@@ -294,8 +289,9 @@ class UserbuffersForwardLinear(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
- basic_op_prev_ops: list[Optional[BasicOperation]],
- basic_op_next_ops: list[Optional[BasicOperation]],
+ prev_op_grad_input_quantizer: Optional[Quantizer],
+ next_op_input_quantizer: Optional[Quantizer],
+ is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -313,7 +309,7 @@ class UserbuffersForwardLinear(FusedOperation):
raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required
- input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
+ input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# Quantization metadata
......@@ -331,9 +327,7 @@ class UserbuffersForwardLinear(FusedOperation):
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
- prev_op = basic_op_prev_ops[0]
- if prev_op is not None and prev_op.num_quantizers("backward") > 0 and recipe.delayed():
-     grad_input_quantizer = prev_op.get_quantizer("backward", 0)
+ grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed
dtype = None
......@@ -376,7 +370,7 @@ class UserbuffersForwardLinear(FusedOperation):
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
- linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
+ linear_op_ctx.has_prev_op = not is_first_op
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -23,6 +23,10 @@ from transformer_engine.pytorch.ops.fused import (
fuse_userbuffers_backward_linear,
fuse_userbuffers_forward_linear,
)
+ from transformer_engine.pytorch.tensor.quantized_tensor import (
+     prepare_for_saving,
+     restore_from_saved,
+ )
def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
......@@ -117,31 +121,33 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
for idx in basic_op_idxs:
basic_op_ctxs[idx].requires_grad = requires_grad
- if requires_grad != x.requires_grad:
-     if requires_grad:
-         x.requires_grad_()
-     else:
-         x = x.detach()
# Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
- prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
- next_ops = [
-     fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None
-     for idx in basic_op_idxs
- ]
+ prev_op_idx = basic_op_idxs[0] - 1
+ prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
+ prev_op_grad_input_quantizer = None
+ if prev_op is not None:
+     prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
+ next_op_idx = basic_op_idxs[-1] + 1
+ next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
+ next_op_input_quantizer = None
+ if next_op is not None:
+     next_op_input_quantizer = next_op.get_input_quantizer()
+ is_first_op = prev_op is None
x, fused_op_extra_outputs = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs],
x,
basic_op_extra_inputs=extra_inputs,
- basic_op_prev_ops=prev_ops,
- basic_op_next_ops=next_ops,
+ prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
+ next_op_input_quantizer=next_op_input_quantizer,
+ is_first_op=is_first_op,
basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
)
- x.requires_grad_(requires_grad=requires_grad)
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
for y in ys:
- y.requires_grad_(requires_grad=requires_grad)
+ y.requires_grad_(requires_grad)
extra_outputs[idx] = ys
# Flatten list of extra outputs
......@@ -169,7 +175,15 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
range_end = len(to_save)
ctx.to_save = None
ctx._saved_tensors_range = (range_start, range_end)
- func_ctx.save_for_backward(*to_save)
+ # Save tensors for backward
+ with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+ if with_quantized_compute:
+     tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
+     func_ctx.save_for_backward(*tensors_to_save)
+     func_ctx.tensor_objects = tensor_objects
+ else:
+     func_ctx.save_for_backward(*to_save)
# Other context
func_ctx.backward_ops = fuser._backward_ops
......@@ -179,9 +193,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
func_ctx.num_extra_inputs = fuser._num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
+ func_ctx.with_quantized_compute = with_quantized_compute
if extra_outputs_flat:
return x, *extra_outputs_flat
+ x.requires_grad_(requires_grad)
return x
@staticmethod
......@@ -198,8 +216,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_ops = func_ctx.basic_ops
basic_op_ctxs = func_ctx.basic_op_ctxs
+ # Restore saved tensors
+ if func_ctx.with_quantized_compute:
+     saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
+ else:
+     saved_tensors = func_ctx.saved_tensors
# Unflatten list of saved tensors
- saved_tensors = func_ctx.saved_tensors
for ctx in basic_op_ctxs:
ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
ctx._saved_tensors_range = None
......
......@@ -68,14 +68,21 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
def pre_forward(self) -> None:
"""Preprocessing before forward pass"""
+ def get_input_quantizer(self) -> Optional[Quantizer]:
+     """Get builder class for quantized input tensor"""
+
+ def get_grad_input_quantizer(self) -> Optional[Quantizer]:
+     """Get builder class for quantized input's grad tensor"""
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
- basic_op_prev_ops: list[Optional[BasicOperation]],
- basic_op_next_ops: list[Optional[BasicOperation]],
+ prev_op_grad_input_quantizer: Optional[Quantizer],
+ next_op_input_quantizer: Optional[Quantizer],
+ is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
"""Forward pass
......@@ -94,12 +101,14 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
Input tensor
basic_op_extra_inputs: list of torch.Tensor
Extra tensor inputs to basic operations
- basic_op_prev_ops: list of BasicOperation
-     Basic operations that preceed this operation's basic
-     operations
- basic_op_next_ops: list of BasicOperation
-     Basic operations that follow this operation's basic
-     operations
+ prev_op_grad_input_quantizer: Quantizer, optional
+     The grad_input_quantizer of the preceding operation
+ next_op_input_quantizer: Quantizer, optional
+     The input_quantizer of the following operation
+ is_first_op: bool
+     Whether this op is the first one in the fuser, i.e. whether it
+     has no preceding op. Used in the backward pass to determine when
+     the saved input tensor can be safely freed.
basic_op_kwargs: list of dict
Keyword arguments to forward functions of basic
operations.
......@@ -201,6 +210,16 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
"""
return 0
+ def get_input_quantizer(self) -> Optional[Quantizer]:
+     if self.num_quantizers("forward") > 0:
+         return self.get_quantizer("forward", 0)
+     return None
+
+ def get_grad_input_quantizer(self) -> Optional[Quantizer]:
+     if self.num_quantizers("backward") > 0:
+         return self.get_quantizer("backward", 0)
+     return None
def _reset_quantization_recipe_state(
self,
*,
......@@ -407,8 +426,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
ctx: OperationContext,
input_: torch.Tensor,
*,
- prev_op: Optional[BasicOperation] = None,
- next_op: Optional[BasicOperation] = None,
+ prev_op_grad_input_quantizer: Optional[Quantizer],
+ next_op_input_quantizer: Optional[Quantizer],
+ is_first_op: bool,
**kwargs: Any,
) -> torch.Tensor:
"""Forward pass
......@@ -419,10 +439,14 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
Context to coordinate between forward and backward passes
input_: torch.Tensor
Input tensor
- prev_op: BasicOperation, optional
-     Basic operation that preceeds this operation
- next_op: BasicOperation, optional
-     Basic operation that follows this operation
+ prev_op_grad_input_quantizer: Quantizer, optional
+     The grad_input_quantizer of the preceding operation
+ next_op_input_quantizer: Quantizer, optional
+     The input_quantizer of the following operation
+ is_first_op: bool
+     Whether this op is the first one in the fuser, i.e. whether it
+     has no preceding op. Used in the backward pass to determine when
+     the saved input tensor can be safely freed.
Returns
-------
......@@ -461,8 +485,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
- basic_op_prev_ops: list[Optional[BasicOperation]],
- basic_op_next_ops: list[Optional[BasicOperation]],
+ prev_op_grad_input_quantizer: Optional[Quantizer],
+ next_op_input_quantizer: Optional[Quantizer],
+ is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, list[tuple[()]]]:
if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
......@@ -475,8 +500,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
output = self.op_forward(
basic_op_ctxs[0],
input_,
- prev_op=basic_op_prev_ops[0],
- next_op=basic_op_next_ops[0],
+ prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
+ next_op_input_quantizer=next_op_input_quantizer,
+ is_first_op=is_first_op,
**basic_op_kwargs[0],
)
return output, [()]
......@@ -696,6 +722,12 @@ class FusedOperation(FusibleOperation):
def is_fused_op(self) -> bool:
return True
+ def get_input_quantizer(self) -> Optional[Quantizer]:
+     return self.basic_ops[0].get_input_quantizer()
+
+ def get_grad_input_quantizer(self) -> Optional[Quantizer]:
+     return self.basic_ops[-1].get_grad_input_quantizer()
def pre_forward(self) -> None:
"""Preprocessing before forward pass"""
for op in self.basic_ops:
......
......@@ -143,6 +143,19 @@ class Float8TensorBase(QuantizedTensorBase):
size = self._transpose.size(*args, **kwargs)
return torch.Size([size[-1], math.prod(size[:-1])])
+ def view(self, shape: torch.Size):
+     # pylint: disable=missing-function-docstring
+     data = self._data
+     if data is not None:
+         return Float8TensorBase(
+             data=data.view(shape),
+             fp8_scale_inv=self._scale_inv,
+             fp8_dtype=self._fp8_dtype,
+             data_transpose=None,
+             quantizer=self._quantizer,
+         )
+     raise RuntimeError("No data available to view")
def __repr__(self):
return (
"Float8TensorBase("
......