"vllm/vscode:/vscode.git/clone" did not exist on "c06170cc8e324f4fe6a0c26b57d09e8c958e11bc"
Unverified commit 04642bf8, authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Add option in activation ops to cache input in FP8 (#1665)



* Add option to cache activation input in FP8
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid casting to FP8 transpose
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Skip input caching if device is not supported
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add documentation that FP8 input caching is experimental
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent dfb3c486
......@@ -1394,6 +1394,7 @@ class TestBasicOps:
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("cache_quantized_input", (False, True))
def test_activation(
self,
*,
......@@ -1402,6 +1403,7 @@ class TestBasicOps:
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
cache_quantized_input: bool,
) -> None:
"""Activation functions"""
......@@ -1413,6 +1415,8 @@ class TestBasicOps:
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
if cache_quantized_input:
maybe_skip_quantization("fp8", device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
......@@ -1463,7 +1467,7 @@ class TestBasicOps:
)[activation]
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantized_compute),
make_op(),
make_op(cache_quantized_input=cache_quantized_input),
te_ops.Quantize(forward=quantized_compute, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
......@@ -1472,9 +1476,9 @@ class TestBasicOps:
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute:
if quantized_compute or cache_quantized_input:
tols = dtype_tols(tex.DType.kFloat8E4M3)
if activation == "relu":
if activation == "relu" and not cache_quantized_input:
tols = {"atol": 0, "rtol": 0}
# Check results
......
......@@ -13,6 +13,7 @@ import torch
import transformer_engine_torch as tex
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer
from ...utils import clear_tensor_data, devices_match
from ..op import BasicOperation, OperationContext
from .._common import reshape
......@@ -37,8 +38,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
the first half of the input tensor, while PyTorch applies it to
the second half.
Parameters
----------
cache_quantized_input: bool, default = False
Quantize input tensor when caching for use in the backward
pass. This will typically reduce memory usage but require
extra compute and increase numerical error. This feature is
highly experimental.
"""
def __init__(self, *, cache_quantized_input: bool = False):
    """Initialize the activation operation.

    Parameters
    ----------
    cache_quantized_input : bool, default = False
        Keyword-only flag. When True, the forward pass quantizes the
        input tensor to FP8 before saving it for the backward pass,
        trading extra compute and numerical error for lower memory
        usage. Documented elsewhere in this class as highly
        experimental.
    """
    super().__init__()
    # Stored on the instance so the forward/backward implementations
    # can decide whether to quantize the cached input.
    self.cache_quantized_input: bool = cache_quantized_input
@abc.abstractmethod
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
"""Forward implementation
......@@ -100,9 +113,16 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
if y.dim() != x.dim():
y = y.reshape(list(x.shape[:-1]) + [-1])
# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
quantizer.set_usage(rowwise=True, columnwise=False)
x = quantizer(x)
# Save state for backward pass
ctx.save_for_backward(x.detach())
ctx.fp8_enabled = fp8_enabled
ctx.dtype = dtype
ctx.prev_op = prev_op
return y
......@@ -116,10 +136,18 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Saved tensors from forward pass
(x,) = ctx.saved_tensors
# Check input tensor
if isinstance(x, QuantizedTensor):
x = x.dequantize(dtype=ctx.dtype)
elif x.dtype != ctx.dtype:
x = x.to(dtype=ctx.dtype)
if not x.is_contiguous():
x = x.contiguous()
# Check grad output tensor
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dy = dy.dequantize(dtype=ctx.dtype)
if not devices_match(dy.device, x.device) or dy.dtype != x.dtype:
dy = dy.to(device=x.device, dtype=x.dtype)
if not dy.is_contiguous():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment