Unverified commit 04642bf8 authored by Tim Moon, committed by GitHub

[PyTorch] Add option in activation ops to cache input in FP8 (#1665)



* Add option to cache activation input in FP8
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid casting to FP8 transpose
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Skip input caching if device is not supported
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add documentation that FP8 input caching is experimental
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent dfb3c486
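For context, a rough user-side sketch of the new option follows. This is only an illustration, not code from the commit: it assumes the GELU activation op exported by transformer_engine.pytorch.ops (the test below constructs the activation ops the same way) and an FP8-capable GPU.

import torch
import transformer_engine.pytorch.ops as te_ops

# Build a small fusible-ops pipeline with FP8 input caching enabled for the
# activation. The op quantizes its cached input to FP8 (E4M3) in the forward
# pass and dequantizes it in the backward pass, trading extra compute and
# some numerical error for lower activation memory.
model = te_ops.Sequential(
    te_ops.GELU(cache_quantized_input=True),  # highly experimental option added in this commit
)

x = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = model(x)        # forward caches x as an FP8 tensor instead of bfloat16
y.sum().backward()  # backward dequantizes the cached input before computing the input grad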
@@ -1394,6 +1394,7 @@ class TestBasicOps:
     @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("cache_quantized_input", (False, True))
     def test_activation(
         self,
         *,
@@ -1402,6 +1403,7 @@ class TestBasicOps:
         dtype: torch.dtype,
         device: torch.device = "cuda",
         quantization: Optional[str],
+        cache_quantized_input: bool,
     ) -> None:
         """Activation functions"""
@@ -1413,6 +1415,8 @@ class TestBasicOps:
         # Skip invalid configurations
         quantized_compute = quantization is not None
         maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        if cache_quantized_input:
+            maybe_skip_quantization("fp8", device=device)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
@@ -1463,7 +1467,7 @@ class TestBasicOps:
         )[activation]
         forward = te_ops.Sequential(
             te_ops.Quantize(forward=False, backward=quantized_compute),
-            make_op(),
+            make_op(cache_quantized_input=cache_quantized_input),
             te_ops.Quantize(forward=quantized_compute, backward=False),
         )
         with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
@@ -1472,9 +1476,9 @@ class TestBasicOps:
         # Expected numerical error
         tols = dtype_tols(dtype)
-        if quantized_compute:
+        if quantized_compute or cache_quantized_input:
             tols = dtype_tols(tex.DType.kFloat8E4M3)
-        if activation == "relu":
+        if activation == "relu" and not cache_quantized_input:
             tols = {"atol": 0, "rtol": 0}

         # Check results
...
@@ -13,6 +13,7 @@ import torch
import transformer_engine_torch as tex
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
+from ...tensor.float8_tensor import Float8CurrentScalingQuantizer
from ...utils import clear_tensor_data, devices_match
from ..op import BasicOperation, OperationContext
from .._common import reshape
@@ -37,8 +38,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
    the first half of the input tensor, while PyTorch applies it to
    the second half.

+    Parameters
+    ----------
+    cache_quantized_input: bool, default = False
+        Quantize input tensor when caching for use in the backward
+        pass. This will typically reduce memory usage but require
+        extra compute and increase numerical error. This feature is
+        highly experimental.
+
    """

+    def __init__(self, *, cache_quantized_input: bool = False):
+        super().__init__()
+        self.cache_quantized_input: bool = cache_quantized_input
+
    @abc.abstractmethod
    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
        """Forward implementation
@@ -100,9 +113,16 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
        if y.dim() != x.dim():
            y = y.reshape(list(x.shape[:-1]) + [-1])

+        # Quantize input to FP8 before caching if needed
+        if self.cache_quantized_input:
+            quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
+            quantizer.set_usage(rowwise=True, columnwise=False)
+            x = quantizer(x)
+
        # Save state for backward pass
        ctx.save_for_backward(x.detach())
        ctx.fp8_enabled = fp8_enabled
+        ctx.dtype = dtype
        ctx.prev_op = prev_op

        return y
@@ -116,10 +136,18 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
        # Saved tensors from forward pass
        (x,) = ctx.saved_tensors

+        # Check input tensor
+        if isinstance(x, QuantizedTensor):
+            x = x.dequantize(dtype=ctx.dtype)
+        elif x.dtype != ctx.dtype:
+            x = x.to(dtype=ctx.dtype)
+        if not x.is_contiguous():
+            x = x.contiguous()
+
        # Check grad output tensor
        dy = grad_output
        if isinstance(dy, QuantizedTensor):
-            dy = dy.dequantize()
+            dy = dy.dequantize(dtype=ctx.dtype)
        if not devices_match(dy.device, x.device) or dy.dtype != x.dtype:
            dy = dy.to(device=x.device, dtype=x.dtype)
        if not dy.is_contiguous():
...
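The core of the change is the quantize-before-save / dequantize-in-backward round trip shown in the hunks above. The standalone sketch below walks through that round trip outside the autograd machinery. It assumes an FP8-capable GPU and that the relative imports in the diff resolve to transformer_engine.pytorch.tensor.float8_tensor; it only illustrates the mechanism and is not part of the commit.

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

x = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16)

# Forward side: quantize the activation input to FP8 (E4M3) with current
# (per-tensor, just-in-time) scaling before stashing it for the backward pass.
quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
quantizer.set_usage(rowwise=True, columnwise=False)  # no columnwise/transpose data needed
x_cached = quantizer(x)  # quantized tensor holding FP8 data, smaller than the bf16 input

# Backward side: dequantize back to the original compute dtype before using
# the cached input in the activation backward kernel.
x_restored = x_cached.dequantize(dtype=torch.bfloat16)
print((x - x_restored).abs().max())  # small, nonzero FP8 rounding error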