Unverified commit ff760a9d authored by Tim Moon, committed by GitHub

[PyTorch] Support pickling Float8Tensor (#529)



* Float8Tensor uses cached transpose if available
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug with non-2D transpose
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Custom pickling for Float8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test for pickling Float8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflict
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @sudhakarsingh27

Avoid FP8 casts when copying between Float8Tensors. Make make_like a class method.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add unit test for checkpointing model with FP8 params

Debugged pickling and copy functions.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 14c51e62
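
In effect, this commit lets `Float8Tensor` (and module checkpoints holding FP8 parameters) round-trip through `torch.save`/`torch.load`, as the new tests below exercise. A minimal sketch, assuming an FP8-capable GPU and `Float8Tensor.to_float8`'s default scale:

```python
import io

import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Cast to FP8, then serialize through an in-memory buffer
x_fp8 = Float8Tensor.to_float8(torch.randn(32, 32, device="cuda"))
buf = io.BytesIO()
torch.save(x_fp8, buf)  # works now that Float8Tensor defines __reduce_ex__

# Deserialize and confirm the FP8 payload and scale survived bit-for-bit
y_fp8 = torch.load(io.BytesIO(buf.getvalue()))
assert torch.equal(y_fp8._data, x_fp8._data)
assert torch.equal(y_fp8._scale_inv, x_fp8._scale_inv)
```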
@@ -3,6 +3,7 @@
# See LICENSE for license information.
from collections.abc import Iterable
import io
from typing import Any, Dict, List, Tuple, Union
import pytest
@@ -263,7 +264,7 @@ class TestFloat8Tensor:
dims: DimsType,
transpose_dims: Tuple[int, int],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 1,
scale: float = 0.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test transpose"""
@@ -316,3 +317,45 @@ class TestFloat8Tensor:
x_ref.transpose(*transpose_dims),
**tols,
)
def test_serialization(
self,
dims: DimsType = [2,3,5],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 0.5,
dtype: torch.dtype = torch.float32,
):
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()
# Serialize tensor
byte_stream = io.BytesIO()
torch.save(x_fp8, byte_stream)
x_bytes = byte_stream.getvalue()
# Mess up and delete old tensor
x_fp8._data.zero_()
x_fp8._scale_inv.zero_()
del x_fp8, byte_stream
# Deserialize tensor
x_fp8 = torch.load(io.BytesIO(x_bytes))
del x_bytes
# Check results
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(x_fp8, x_ref, **tols)
# Make sure we are not trivially passing tests
x_fp8._data.zero_()
x_fp8._scale_inv.zero_()
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, x_ref, **tols)
@@ -11,15 +11,22 @@ The test verifies the values of FP8 metadata object after saving and loading a c
are identical to the original values.
"""
import io
import tempfile
from typing import Iterable, Union
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
def init_meta(size: int=1):
meta = tex.FP8TensorMeta()
@@ -29,16 +36,13 @@ def init_meta(size: int=1):
return meta
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("scale_fwd", [224, 112, 66])
@pytest.mark.parametrize("scale_bwd", [448, 33])
@pytest.mark.parametrize("history_fwd", [1.23, 4.56])
@pytest.mark.parametrize("history_bwd", [2.34, 5.67])
def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd):
# Skip FP8 tests on non-hopper devices
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
tmp_filename = tempfile.NamedTemporaryFile().name
precision = torch.float32
@@ -118,3 +122,113 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv)
assert torch.allclose(model_in.fp8_meta["scaling_bwd"].amax_history, model_out.fp8_meta["scaling_bwd"].amax_history)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("save_fp8_model", [True, False])
@pytest.mark.parametrize("load_fp8_model", [True, False])
def test_fp8_model_checkpoint(
save_fp8_model: bool,
load_fp8_model: bool,
dims: Iterable[int] = [32,32],
dtype: torch.dtype = torch.float32,
device: Union[torch.device, str] = "cuda",
):
# Construct model
dims = list(dims)
hidden_dim = dims[-1]
with te.fp8_model_init(enabled=save_fp8_model):
model = te.Linear(
hidden_dim,
hidden_dim,
bias=False,
params_dtype=dtype,
device=device,
)
# Keep track of model output
x = torch.randn(dims, dtype=dtype, device=device)
with te.fp8_autocast():
y_ref = model(x.detach().clone()).detach().clone()
# Keep track of weights and FP8 scaling factors
weight_ref = model.weight.float().detach().clone()
fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} }
with te.fp8_autocast(), torch.no_grad():
fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
fp8_meta_fwd_ref["scale"] = torch.rand_like(fp8_meta_fwd.scale) + 0.5
fp8_meta_fwd_ref["scale_inv"] = fp8_meta_fwd_ref["scale"].reciprocal()
fp8_meta_bwd_ref["scale"] = torch.rand_like(fp8_meta_bwd.scale) + 0.5
fp8_meta_bwd_ref["scale_inv"] = fp8_meta_bwd_ref["scale"].reciprocal()
fp8_meta_fwd.scale.copy_(fp8_meta_fwd_ref["scale"])
fp8_meta_fwd.scale_inv.copy_(fp8_meta_fwd_ref["scale_inv"])
fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"])
fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"])
del fp8_meta_fwd, fp8_meta_bwd
# Save checkpoint
byte_stream = io.BytesIO()
torch.save(model.state_dict(), byte_stream)
model_bytes = byte_stream.getvalue()
del byte_stream
# Disturb and destroy model
with torch.no_grad():
model.weight.zero_()
model.fp8_meta = {"This": "is", "filled": "with", "nonsense": 1234}
del model
# Construct new model
with te.fp8_model_init(enabled=load_fp8_model):
model = te.Linear(
hidden_dim,
hidden_dim,
bias=False,
params_dtype=dtype,
device=device,
)
# Make sure new model does not match saved model
tols = dict(rtol=0.125, atol=0.0675) # fp8e4m3 epsilon = 0.0625
with pytest.raises(AssertionError):
torch.testing.assert_close(model.weight, weight_ref, **tols)
with te.fp8_autocast():
model.init_fp8_metadata()
fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
with pytest.raises(AssertionError):
torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"])
with pytest.raises(AssertionError):
torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"])
with pytest.raises(AssertionError):
torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"])
with pytest.raises(AssertionError):
torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"])
with te.fp8_autocast():
y = model(x.detach().clone())
with pytest.raises(AssertionError):
torch.testing.assert_close(y, y_ref, **tols)
# Load checkpoint
model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
del model_bytes
# Check that loaded model matches saved model
torch.testing.assert_close(model.weight, weight_ref, **tols)
with te.fp8_autocast():
fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"])
torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"])
torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"])
torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"])
with te.fp8_autocast():
y = model(x.detach().clone())
torch.testing.assert_close(y, y_ref, **tols)
@@ -435,30 +435,12 @@ class Float8Tensor(torch.Tensor):
return _IdentityFunc.apply(self)
return super().expand_as(other)
def _transpose_no_cache(self) -> torch.Tensor:
"""
Swap tensor dimensions
For basic 2D matrix transposes, an optimized transpose kernel
is applied and a Float8Tensor is returned.
"""
# Use optimized kernel for basic 2D transpose
# TODO Support differentiation # pylint: disable=fixme
return Float8Tensor.make_like(
self,
data=tex.fp8_transpose(
self._data.contiguous().detach(),
self._fp8_dtype,
),
)
def transpose(
self,
dim0: int = 0,
dim1: int = 1,
*,
update_cache: Optional[bool] = None,
update_cache: bool = False,
) -> torch.Tensor:
"""
Swap tensor dimensions
@@ -472,12 +454,14 @@ class Float8Tensor(torch.Tensor):
The first dimension to be transposed
dim1: int, default = 1
The second dimension to be transposed
update_cache: Optional[bool], default = None
If set to `True`, the result is computed and stored in a cache.
If set to `False`, the result is computed only if the cache is
empty, otherwise the cache is returned. If set to `None`, the
result is not cached. Caching is only supported for basic 2D
transposes and the cache is reset after any in-place operations.
update_cache: bool, default = False
If `True`, the transpose is computed and stored
in a cache. If `False`, a cached version is
returned if available and otherwise the
transpose is computed. Caching is only supported
for basic 2D transposes and the cache is reset
after any in-place operations.
"""
# Handle non-2D transposes
@@ -486,22 +470,32 @@ class Float8Tensor(torch.Tensor):
if -self.dim() <= dim1 < 0:
dim1 += self.dim()
if self.dim() != 2 or dim0 == dim1:
if update_cache is not None:
if update_cache:
raise ValueError(
"Transpose caching is only supported for basic 2D transposes "
f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
)
return super().transpose(dim0, dim1)
# No caching.
if update_cache is None:
return self._transpose_no_cache()
# Clear cache if needed
if update_cache:
self._transpose = None
# Update cache.
if update_cache or self._transpose is None:
self._transpose = self._transpose_no_cache()
# Compute transpose if needed
out = self._transpose
if out is None:
out = Float8Tensor.make_like(
self,
data=tex.fp8_transpose(
self._data.contiguous(),
self._fp8_dtype,
),
)
return self._transpose
# Update cache if needed
if update_cache:
self._transpose = out
return out
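
With the `Optional[bool]` tri-state gone, the caching contract reads as: `update_cache=False` serves a cached transpose when one exists (computing one without caching it otherwise), while `update_cache=True` recomputes and stores it. A usage sketch under the same assumptions as above:

```python
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

x = Float8Tensor.to_float8(torch.randn(16, 16, device="cuda"))

xt = x.transpose()                   # cache empty: computed, not stored
xt = x.transpose(update_cache=True)  # recomputed and stored in x._transpose
xt = x.transpose()                   # served from the cache

# Non-2D transposes fall back to torch.Tensor.transpose; asking them
# to cache raises ValueError
x3 = Float8Tensor.to_float8(torch.randn(2, 3, 4, device="cuda"))
x3.transpose(0, 2)                   # fine, no caching involved
```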
@torch.no_grad()
def reset_fp8_meta_scale_inv(self) -> None:
@@ -550,16 +544,35 @@ class Float8Tensor(torch.Tensor):
# Check tensors
dst = args[0]
src = args[1]
if not isinstance(dst, Float8Tensor):
raise RuntimeError("Expected to copy into Float8Tensor")
if not isinstance(dst, torch.Tensor):
raise RuntimeError(
"Attempted to copy into something that isn't a PyTorch tensor"
)
if not isinstance(src, torch.Tensor):
raise RuntimeError("Expected to copy from tensor")
if not dst._data.is_contiguous():
raise RuntimeError("Transformer Engine cast kernels require contiguous data")
raise RuntimeError(
"Attempted to copy from something that isn't a PyTorch tensor"
)
# Special handling based on which tensors are FP8
dst_is_fp8 = isinstance(dst, Float8Tensor)
src_is_fp8 = isinstance(src, Float8Tensor)
if dst_is_fp8 and src_is_fp8:
# Directly copy FP8 data if possible
if dst._fp8_dtype == src._fp8_dtype:
dst._data.copy_(src._data)
dst._scale_inv = src._scale_inv.clone()
else:
dst.copy_(src.from_float8())
elif not dst_is_fp8 and src_is_fp8:
# Cast source tensor to higher precision
dst.copy_(src.from_float8())
elif dst_is_fp8 and not src_is_fp8:
# Make sure input is in expected format
if isinstance(src, Float8Tensor):
src = src.from_float8()
src = src.expand(dst.size())
src = src.to(
device=dst.device,
@@ -577,6 +590,8 @@ class Float8Tensor(torch.Tensor):
dst._scale_inv = scale.detach().view(1).reciprocal()
# Cast to FP8
if not dst._data.is_contiguous():
raise RuntimeError("Transformer Engine cast kernels require contiguous data")
tex.cast_to_fp8_noalloc(
src.view(1,-1),
scale,
@@ -586,7 +601,13 @@ class Float8Tensor(torch.Tensor):
dst._fp8_dtype,
)
else:
# Invalid case
raise RuntimeError("Using Float8Tensor copy logic, but no Float8Tensor found")
# Nothing to return for in-place ops
if dst_is_fp8:
dst._reset_caches()
return None
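
The reworked `copy_` dispatch now handles every FP8/non-FP8 combination. A sketch of the behavior (hypothetical tensors, FP8-capable GPU assumed):

```python
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

a = Float8Tensor.to_float8(torch.randn(8, 8, device="cuda"))
b = Float8Tensor.to_float8(torch.randn(8, 8, device="cuda"))
hp = torch.empty(8, 8, device="cuda")

b.copy_(a)   # FP8 -> FP8 (same fp8_dtype): raw bytes copied, scale_inv cloned
hp.copy_(a)  # FP8 -> high precision: dequantized via from_float8
a.copy_(hp)  # high precision -> FP8: cast using the destination's scale
# Any in-place copy into a Float8Tensor also resets its transpose cache
```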
@@ -658,6 +679,34 @@ class Float8Tensor(torch.Tensor):
out = super().__torch_dispatch__(func, types, args, kwargs)
return out
@classmethod
def _make_in_reduce_ex(
cls,
data: torch.Tensor,
fp8_dtype: tex.DType,
fp8_scale_inv: torch.Tensor,
dtype: torch.dtype,
) -> Float8Tensor:
"""Build Float8Tensor, for use in __reduce__
__reduce_ex__ assumes object constructor has positional
arguments.
"""
return Float8Tensor(
data=data,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
dtype=dtype,
)
def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling to remove references to FP8 metadata objects"""
return (
Float8Tensor._make_in_reduce_ex,
(self._data, self._fp8_dtype, self._scale_inv, self.dtype),
)
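
Because `__reduce_ex__` returns a plain `(callable, args)` pair with no references to `fp8_meta`, the pickled payload carries only the raw byte buffer, the FP8 dtype enum, the `scale_inv` tensor, and the nominal dtype; unpickling simply replays `_make_in_reduce_ex` on those four values. A quick round-trip sketch:

```python
import pickle

import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

x = Float8Tensor.to_float8(torch.randn(4, 4, device="cuda"))
y = pickle.loads(pickle.dumps(x))  # replays _make_in_reduce_ex(...)

assert torch.equal(y._data, x._data)
assert torch.equal(y._scale_inv, x._scale_inv)
```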
def _get_data(self) -> Float8Tensor:
"""Get tensor data property"""
return super().data
@@ -819,19 +819,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
)
return fp8_weight_tensors
def state_dict(self, *args, **kwargs) -> Dict:
"""Get dictionary containing module state"""
state = super().state_dict(*args, **kwargs)
# Convert Float8Tensors to plain tensors
# Note: Float8Tensors don't serialize well, especially if they
# contain references to FP8 metadata.
for key, val in state.items():
if isinstance(val, Float8Tensor):
state[key] = val.from_float8()
return state
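
With native pickling in place, this `state_dict` override (which dequantized FP8 parameters to plain tensors) is no longer needed; checkpoints can keep parameters in FP8. A sketch of the resulting workflow, assuming FP8-capable hardware:

```python
import torch
import transformer_engine.pytorch as te

# Parameters are created directly as Float8Tensors
with te.fp8_model_init(enabled=True):
    model = te.Linear(32, 32, bias=False)

state = model.state_dict()    # may now contain Float8Tensor entries
torch.save(state, "ckpt.pt")  # serialized without dequantization
model.load_state_dict(torch.load("ckpt.pt"))
```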
@abstractmethod
def forward(self):
"""Needs override."""