Unverified Commit 21b780cc authored by Jan Bielak's avatar Jan Bielak Committed by GitHub
Browse files

Enable use of internal tensors in Sequential (#1900)



* Replace `is_float8_tensor` with `is_quantized_tensor`

Replace free function `is_float8_tensor` with `is_quantized_tensor` in `_common.py` and use it throughout the `ops` codebase to check if a tensor is a (possibly internal) quantized tensor
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Pass next and previous op quantizers directly to op_forward and fuser_forward

Change interface of `fuser_forward` and `op_forward` to no longer take preceding and following ops and instead take the following op's input quantizer and preceding op's input gradient's quantizer directly
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Remove redundant use of `detach` in `BasicLinear`

Remove use of `detach` in `BasicLinear` for improved performance (enabled by not passing prev_op to backward)
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Handle saving internal tensors

Handle saving internal tensors in `_OperationFuserAutogradFunction` using `prepare_for_saving` and `restore_from_saved`
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Use internal tensors

Enable use of internal tensors in `BasicLinear` quantizers and fix issues resulting from internal tensors not having methods that regular tensors have
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Apply suggestions from code review
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 447de6da
...@@ -20,11 +20,11 @@ from ...module.base import ( ...@@ -20,11 +20,11 @@ from ...module.base import (
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
) )
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer from ...tensor.quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..basic import BasicLinear, Bias, ReduceScatter from ..basic import BasicLinear, Bias, ReduceScatter
from .._common import maybe_dequantize from .._common import maybe_dequantize, is_quantized_tensor
from ..op import FusedOperation, FusibleOperation, OperationContext from ..op import FusedOperation, FusibleOperation, OperationContext
...@@ -280,7 +280,7 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -280,7 +280,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Cast grad output tensor dtype if needed # Cast grad output tensor dtype if needed
dy_local = grad_output dy_local = grad_output
if with_quantized_compute: if with_quantized_compute:
if not isinstance(dy_local, QuantizedTensorBase): if not is_quantized_tensor(dy_local):
with_columnwise = weight_requires_grad with_columnwise = weight_requires_grad
if ( if (
with_columnwise with_columnwise
...@@ -301,7 +301,7 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -301,7 +301,7 @@ class UserbuffersBackwardLinear(FusedOperation):
raise ValueError("Weight tensor is required to compute input grad") raise ValueError("Weight tensor is required to compute input grad")
w = weight w = weight
if with_quantized_compute: if with_quantized_compute:
if not isinstance(w, QuantizedTensorBase): if not is_quantized_tensor(w):
weight_quantizer.set_usage(columnwise=True) weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w) w = weight_quantizer(w)
else: else:
...@@ -314,7 +314,7 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -314,7 +314,7 @@ class UserbuffersBackwardLinear(FusedOperation):
raise ValueError("Input tensor is required to compute weight grad") raise ValueError("Input tensor is required to compute weight grad")
x_local = input x_local = input
if with_quantized_compute: if with_quantized_compute:
if not isinstance(x_local, QuantizedTensorBase): if not is_quantized_tensor(x_local):
input_quantizer.set_usage(columnwise=True) input_quantizer.set_usage(columnwise=True)
x_local = input_quantizer(x_local) x_local = input_quantizer(x_local)
else: else:
...@@ -425,7 +425,7 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -425,7 +425,7 @@ class UserbuffersBackwardLinear(FusedOperation):
raise RuntimeError( raise RuntimeError(
"wgrad GEMM requires grad output tensor, which has not been initialized" "wgrad GEMM requires grad output tensor, which has not been initialized"
) )
if isinstance(dy, QuantizedTensorBase): if is_quantized_tensor(dy):
dy.update_usage(rowwise_usage=False, columnwise_usage=True) dy.update_usage(rowwise_usage=False, columnwise_usage=True)
# Initialize input tensor # Initialize input tensor
...@@ -435,7 +435,7 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -435,7 +435,7 @@ class UserbuffersBackwardLinear(FusedOperation):
raise RuntimeError( raise RuntimeError(
"wgrad GEMM requires input tensor, which has not been initialized" "wgrad GEMM requires input tensor, which has not been initialized"
) )
if isinstance(x, QuantizedTensorBase): if is_quantized_tensor(x):
x.update_usage(rowwise_usage=False, columnwise_usage=True) x.update_usage(rowwise_usage=False, columnwise_usage=True)
# Check grad weight tensor # Check grad weight tensor
......
...@@ -20,14 +20,13 @@ from ...module.base import ( ...@@ -20,14 +20,13 @@ from ...module.base import (
get_workspace, get_workspace,
_2X_ACC_FPROP, _2X_ACC_FPROP,
) )
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer from ...tensor.quantized_tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype from ...utils import canonicalize_device, canonicalize_dtype
from .._common import maybe_dequantize from .._common import maybe_dequantize, is_quantized_tensor
from ..basic import BasicLinear, Bias, ReduceScatter from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import ( from ..op import (
BasicOperation,
FusedOperation, FusedOperation,
FusibleOperation, FusibleOperation,
OperationContext, OperationContext,
...@@ -207,7 +206,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -207,7 +206,7 @@ class UserbuffersForwardLinear(FusedOperation):
x = None x = None
if with_ub_all_gather: if with_ub_all_gather:
if input_quantizer is not None: if input_quantizer is not None:
if not isinstance(x_local, QuantizedTensorBase): if not is_quantized_tensor(x_local):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if isinstance( if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
...@@ -223,7 +222,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -223,7 +222,7 @@ class UserbuffersForwardLinear(FusedOperation):
) )
else: else:
if with_quantized_compute: if with_quantized_compute:
if not isinstance(x_local, QuantizedTensorBase): if not is_quantized_tensor(x_local):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
x_local = input_quantizer(x_local) x_local = input_quantizer(x_local)
else: else:
...@@ -234,7 +233,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -234,7 +233,7 @@ class UserbuffersForwardLinear(FusedOperation):
w = weight w = weight
if not with_quantized_compute: if not with_quantized_compute:
w = maybe_dequantize(w, dtype) w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not isinstance(w, QuantizedTensorBase): elif with_quantized_compute and not is_quantized_tensor(w):
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w) w = weight_quantizer(w)
...@@ -266,18 +265,14 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -266,18 +265,14 @@ class UserbuffersForwardLinear(FusedOperation):
# Prepare weight tensor for backward pass # Prepare weight tensor for backward pass
if input_requires_grad: if input_requires_grad:
if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensorBase): if w is not weight and with_quantized_compute and is_quantized_tensor(w):
w.update_usage(rowwise_usage=False, columnwise_usage=True) w.update_usage(rowwise_usage=False, columnwise_usage=True)
else: else:
w = None w = None
# Prepare input tensor for backward pass # Prepare input tensor for backward pass
if weight_requires_grad: if weight_requires_grad:
if x_local is input: if with_quantized_compute and is_quantized_tensor(x_local):
# PyTorch autograd produces esoteric errors if we
# cache input tensor directly.
x_local = x_local.detach()
if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather): if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data # FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True) x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
...@@ -294,8 +289,9 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -294,8 +289,9 @@ class UserbuffersForwardLinear(FusedOperation):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]], prev_op_grad_input_quantizer: Optional[Quantizer],
basic_op_next_ops: list[Optional[BasicOperation]], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...@@ -313,7 +309,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -313,7 +309,7 @@ class UserbuffersForwardLinear(FusedOperation):
raise ValueError("Bias operation forward does not expect keyword arguments") raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required # Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# Quantization metadata # Quantization metadata
...@@ -331,9 +327,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -331,9 +327,7 @@ class UserbuffersForwardLinear(FusedOperation):
input_quantizer = linear_op.get_quantizer("forward", 0) input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1) weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_output_quantizer = linear_op.get_quantizer("backward", 0)
prev_op = basic_op_prev_ops[0] grad_input_quantizer = prev_op_grad_input_quantizer
if prev_op is not None and prev_op.num_quantizers("backward") > 0 and recipe.delayed():
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
# Get autocast dtype if needed # Get autocast dtype if needed
dtype = None dtype = None
...@@ -376,7 +370,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -376,7 +370,7 @@ class UserbuffersForwardLinear(FusedOperation):
linear_op_ctx.input_dims = input_.size() linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None linear_op_ctx.has_prev_op = not is_first_op
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
......
...@@ -23,6 +23,10 @@ from transformer_engine.pytorch.ops.fused import ( ...@@ -23,6 +23,10 @@ from transformer_engine.pytorch.ops.fused import (
fuse_userbuffers_backward_linear, fuse_userbuffers_backward_linear,
fuse_userbuffers_forward_linear, fuse_userbuffers_forward_linear,
) )
from transformer_engine.pytorch.tensor.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
...@@ -117,31 +121,33 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -117,31 +121,33 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
for idx in basic_op_idxs: for idx in basic_op_idxs:
basic_op_ctxs[idx].requires_grad = requires_grad basic_op_ctxs[idx].requires_grad = requires_grad
if requires_grad != x.requires_grad:
if requires_grad:
x.requires_grad_()
else:
x = x.detach()
# Forward op # Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] prev_op_idx = basic_op_idxs[0] - 1
next_ops = [ prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx > 0 else None
fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None prev_op_grad_input_quantizer = None
for idx in basic_op_idxs if prev_op is not None:
] prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
next_op_idx = basic_op_idxs[-1] + 1
next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
next_op_input_quantizer = None
if next_op is not None:
next_op_input_quantizer = next_op.get_input_quantizer()
is_first_op = prev_op is None
x, fused_op_extra_outputs = op.fuser_forward( x, fused_op_extra_outputs = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs], [basic_op_ctxs[idx] for idx in basic_op_idxs],
x, x,
basic_op_extra_inputs=extra_inputs, basic_op_extra_inputs=extra_inputs,
basic_op_prev_ops=prev_ops, prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
basic_op_next_ops=next_ops, next_op_input_quantizer=next_op_input_quantizer,
is_first_op=is_first_op,
basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
) )
x.requires_grad_(requires_grad=requires_grad)
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
for y in ys: for y in ys:
y.requires_grad_(requires_grad=requires_grad) y.requires_grad_(requires_grad)
extra_outputs[idx] = ys extra_outputs[idx] = ys
# Flatten list of extra outputs # Flatten list of extra outputs
...@@ -169,7 +175,15 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -169,7 +175,15 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
range_end = len(to_save) range_end = len(to_save)
ctx.to_save = None ctx.to_save = None
ctx._saved_tensors_range = (range_start, range_end) ctx._saved_tensors_range = (range_start, range_end)
func_ctx.save_for_backward(*to_save)
# Save tensors for backward
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
func_ctx.save_for_backward(*tensors_to_save)
func_ctx.tensor_objects = tensor_objects
else:
func_ctx.save_for_backward(*to_save)
# Other context # Other context
func_ctx.backward_ops = fuser._backward_ops func_ctx.backward_ops = fuser._backward_ops
...@@ -179,9 +193,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -179,9 +193,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
func_ctx.num_extra_inputs = fuser._num_extra_inputs func_ctx.num_extra_inputs = fuser._num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
func_ctx.with_quantized_compute = with_quantized_compute
if extra_outputs_flat: if extra_outputs_flat:
return x, *extra_outputs_flat return x, *extra_outputs_flat
x.requires_grad_(requires_grad)
return x return x
@staticmethod @staticmethod
...@@ -198,8 +216,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -198,8 +216,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_ops = func_ctx.basic_ops basic_ops = func_ctx.basic_ops
basic_op_ctxs = func_ctx.basic_op_ctxs basic_op_ctxs = func_ctx.basic_op_ctxs
# Restore saved tensors
if func_ctx.with_quantized_compute:
saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
else:
saved_tensors = func_ctx.saved_tensors
# Unflatten list of saved tensors # Unflatten list of saved tensors
saved_tensors = func_ctx.saved_tensors
for ctx in basic_op_ctxs: for ctx in basic_op_ctxs:
ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)] ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
ctx._saved_tensors_range = None ctx._saved_tensors_range = None
......
...@@ -68,14 +68,21 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -68,14 +68,21 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
def pre_forward(self) -> None: def pre_forward(self) -> None:
"""Preprocessing before forward pass""" """Preprocessing before forward pass"""
def get_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input tensor"""
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input's grad tensor"""
def fuser_forward( def fuser_forward(
self, self,
basic_op_ctxs: list[OperationContext], basic_op_ctxs: list[OperationContext],
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]], prev_op_grad_input_quantizer: Optional[Quantizer],
basic_op_next_ops: list[Optional[BasicOperation]], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
"""Forward pass """Forward pass
...@@ -94,12 +101,14 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -94,12 +101,14 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
Input tensor Input tensor
basic_op_extra_inputs: list of torch.Tensor basic_op_extra_inputs: list of torch.Tensor
Extra tensor inputs to basic operations Extra tensor inputs to basic operations
basic_op_prev_ops: list of BasicOperation prev_op_grad_input_quantizer: Quantizer, optional
                 Basic operations that precede this operation's basic                 The grad_input_quantizer of the preceding operation
operations next_op_input_quantizer: Quantizer, optional
basic_op_next_ops: list of BasicOperation The input_quantizer of the following operation
Basic operations that follow this operation's basic is_first_op: bool
                 operations                 Does this op have a preceding op or is it the first one in the
fuser. Used in the backward pass to safely delete the saved input
                 tensor when no longer needed and there is a preceding op.
basic_op_kwargs: list of dict basic_op_kwargs: list of dict
Keyword arguments to forward functions of basic Keyword arguments to forward functions of basic
operations. operations.
...@@ -201,6 +210,16 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -201,6 +210,16 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
""" """
return 0 return 0
def get_input_quantizer(self) -> Optional[Quantizer]:
if self.num_quantizers("forward") > 0:
return self.get_quantizer("forward", 0)
return None
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
if self.num_quantizers("backward") > 0:
return self.get_quantizer("backward", 0)
return None
def _reset_quantization_recipe_state( def _reset_quantization_recipe_state(
self, self,
*, *,
...@@ -407,8 +426,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -407,8 +426,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
*, *,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass """Forward pass
...@@ -419,10 +439,14 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -419,10 +439,14 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
Context to coordinate between forward and backward passes Context to coordinate between forward and backward passes
input_: torch.Tensor input_: torch.Tensor
Input tensor Input tensor
prev_op: BasicOperation, optional prev_op_grad_input_quantizer: Quantizer, optional
             Basic operation that precedes this operation             The grad_input_quantizer of the preceding operation
next_op: BasicOperation, optional next_op_input_quantizer: Quantizer, optional
Basic operation that follows this operation The input_quantizer of the following operation
is_first_op: bool
             Does this op have a preceding op or is it the first one in the
fuser. Used in the backward pass to safely delete the saved input
             tensor when no longer needed and there is a preceding op.
Returns Returns
------- -------
...@@ -461,8 +485,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -461,8 +485,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]], prev_op_grad_input_quantizer: Optional[Quantizer],
basic_op_next_ops: list[Optional[BasicOperation]], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, list[tuple[()]]]: ) -> tuple[torch.Tensor, list[tuple[()]]]:
if self.num_extra_inputs > 0 or self.num_extra_outputs > 0: if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
...@@ -475,8 +500,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -475,8 +500,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
output = self.op_forward( output = self.op_forward(
basic_op_ctxs[0], basic_op_ctxs[0],
input_, input_,
prev_op=basic_op_prev_ops[0], prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
next_op=basic_op_next_ops[0], next_op_input_quantizer=next_op_input_quantizer,
is_first_op=is_first_op,
**basic_op_kwargs[0], **basic_op_kwargs[0],
) )
return output, [()] return output, [()]
...@@ -696,6 +722,12 @@ class FusedOperation(FusibleOperation): ...@@ -696,6 +722,12 @@ class FusedOperation(FusibleOperation):
def is_fused_op(self) -> bool: def is_fused_op(self) -> bool:
return True return True
def get_input_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[0].get_input_quantizer()
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[-1].get_grad_input_quantizer()
def pre_forward(self) -> None: def pre_forward(self) -> None:
"""Preprocessing before forward pass""" """Preprocessing before forward pass"""
for op in self.basic_ops: for op in self.basic_ops:
......
...@@ -143,6 +143,19 @@ class Float8TensorBase(QuantizedTensorBase): ...@@ -143,6 +143,19 @@ class Float8TensorBase(QuantizedTensorBase):
size = self._transpose.size(*args, **kwargs) size = self._transpose.size(*args, **kwargs)
return torch.Size([size[-1], math.prod(size[:-1])]) return torch.Size([size[-1], math.prod(size[:-1])])
def view(self, shape: torch.Size):
# pylint: disable=missing-function-docstring
data = self._data
if data is not None:
return Float8TensorBase(
data=data.view(shape),
fp8_scale_inv=self._scale_inv,
fp8_dtype=self._fp8_dtype,
data_transpose=None,
quantizer=self._quantizer,
)
raise RuntimeError("No data available to view")
def __repr__(self): def __repr__(self):
return ( return (
"Float8TensorBase(" "Float8TensorBase("
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment