Unverified commit ff760a9d, authored by Tim Moon and committed by GitHub

[PyTorch] Support pickling Float8Tensor (#529)



* Float8Tensor uses cached transpose if available
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug with non-2D transpose
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Custom pickling for Float8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test for pickling Float8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflict
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @sudhakarsingh27

Avoid FP8 casts when copying between Float8Tensors. Make make_like a class method.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add unit test for checkpointing model with FP8 params

Debugged pickling and copy functions.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 14c51e62
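In practice, this change means a Float8Tensor can now round-trip through torch.save and torch.load on its own. A minimal sketch of that workflow, mirroring the test_serialization test added below (the import path and default casting arguments are assumptions, not part of this diff):

    import io
    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor  # path assumed

    x_fp8 = Float8Tensor.to_float8(2 * torch.rand([2, 3, 5]) - 1)  # FP8 cast with default args
    x_ref = x_fp8.from_float8()                    # dequantized reference values
    buf = io.BytesIO()
    torch.save(x_fp8, buf)                         # pickling a Float8Tensor now works
    y_fp8 = torch.load(io.BytesIO(buf.getvalue()))
    torch.testing.assert_close(y_fp8, x_ref, rtol=0, atol=0)  # exact round trip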
@@ -3,6 +3,7 @@
 # See LICENSE for license information.

 from collections.abc import Iterable
+import io
 from typing import Any, Dict, List, Tuple, Union

 import pytest
@@ -263,7 +264,7 @@ class TestFloat8Tensor:
         dims: DimsType,
         transpose_dims: Tuple[int, int],
         fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
-        scale: float = 1,
+        scale: float = 0.5,
         dtype: torch.dtype = torch.float32,
     ) -> None:
         """Test transpose"""
@@ -316,3 +317,45 @@ class TestFloat8Tensor:
             x_ref.transpose(*transpose_dims),
             **tols,
         )
+
+    def test_serialization(
+        self,
+        dims: DimsType = [2,3,5],
+        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
+        scale: float = 0.5,
+        dtype: torch.dtype = torch.float32,
+    ):
+
+        # Initialize random data
+        dims = _to_list(dims)
+        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
+        x_fp8 = Float8Tensor.to_float8(
+            x_ref,
+            fp8_dtype=fp8_dtype,
+            scale=torch.full([1], scale),
+        )
+        x_ref = x_fp8.from_float8()
+
+        # Serialize tensor
+        byte_stream = io.BytesIO()
+        torch.save(x_fp8, byte_stream)
+        x_bytes = byte_stream.getvalue()
+
+        # Mess up and delete old tensor
+        x_fp8._data.zero_()
+        x_fp8._scale_inv.zero_()
+        del x_fp8, byte_stream
+
+        # Deserialize tensor
+        x_fp8 = torch.load(io.BytesIO(x_bytes))
+        del x_bytes
+
+        # Check results
+        tols = dict(rtol=0, atol=0)
+        torch.testing.assert_close(x_fp8, x_ref, **tols)
+
+        # Make sure we are not trivially passing tests
+        x_fp8._data.zero_()
+        x_fp8._scale_inv.zero_()
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(x_fp8, x_ref, **tols)
@@ -11,15 +11,22 @@ The test verifies the values of FP8 metadata object after saving and loading a checkpoint
 are identical to the original values.
 """
+import io
 import tempfile
+from typing import Iterable, Union

 import pytest
 import torch

 import transformer_engine.pytorch as te
 import transformer_engine_extensions as tex
 from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.module.base import get_workspace
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

+# Check if FP8 is supported
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+

 def init_meta(size: int=1):
     meta = tex.FP8TensorMeta()
@@ -29,16 +36,13 @@ def init_meta(size: int=1):
     return meta


+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.parametrize("scale_fwd", [224, 112, 66])
 @pytest.mark.parametrize("scale_bwd", [448, 33])
 @pytest.mark.parametrize("history_fwd", [1.23, 4.56])
 @pytest.mark.parametrize("history_bwd", [2.34, 5.67])
 def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd):
-
-    # Skip FP8 tests on non-hopper devices
-    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
-        pytest.skip("Device compute capability 9.x required for FP8 execution.")

     tmp_filename = tempfile.NamedTemporaryFile().name

     precision = torch.float32
@@ -118,3 +122,113 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd):
     assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv)
     assert torch.allclose(model_in.fp8_meta["scaling_bwd"].amax_history, model_out.fp8_meta["scaling_bwd"].amax_history)
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+@pytest.mark.parametrize("save_fp8_model", [True, False])
+@pytest.mark.parametrize("load_fp8_model", [True, False])
+def test_fp8_model_checkpoint(
+    save_fp8_model: bool,
+    load_fp8_model: bool,
+    dims: Iterable[int] = [32,32],
+    dtype: torch.dtype = torch.float32,
+    device: Union[torch.device, str] = "cuda",
+):
+
+    # Construct model
+    dims = list(dims)
+    hidden_dim = dims[-1]
+    with te.fp8_model_init(enabled=save_fp8_model):
+        model = te.Linear(
+            hidden_dim,
+            hidden_dim,
+            bias=False,
+            params_dtype=dtype,
+            device=device,
+        )
+
+    # Keep track of model output
+    x = torch.randn(dims, dtype=dtype, device=device)
+    with te.fp8_autocast():
+        y_ref = model(x.detach().clone()).detach().clone()
+
+    # Keep track of weights and FP8 scaling factors
+    weight_ref = model.weight.float().detach().clone()
+    fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} }
+    with te.fp8_autocast(), torch.no_grad():
+        fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
+        fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
+        fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
+        fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
+        fp8_meta_fwd_ref["scale"] = torch.rand_like(fp8_meta_fwd.scale) + 0.5
+        fp8_meta_fwd_ref["scale_inv"] = fp8_meta_fwd_ref["scale"].reciprocal()
+        fp8_meta_bwd_ref["scale"] = torch.rand_like(fp8_meta_bwd.scale) + 0.5
+        fp8_meta_bwd_ref["scale_inv"] = fp8_meta_bwd_ref["scale"].reciprocal()
+        fp8_meta_fwd.scale.copy_(fp8_meta_fwd_ref["scale"])
+        fp8_meta_fwd.scale_inv.copy_(fp8_meta_fwd_ref["scale_inv"])
+        fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"])
+        fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"])
+    del fp8_meta_fwd, fp8_meta_bwd
+
+    # Save checkpoint
+    byte_stream = io.BytesIO()
+    torch.save(model.state_dict(), byte_stream)
+    model_bytes = byte_stream.getvalue()
+    del byte_stream
+
+    # Disturb and destroy model
+    with torch.no_grad():
+        model.weight.zero_()
+    model.fp8_meta = {"This": "is", "filled": "with", "nonsense": 1234}
+    del model
+
+    # Construct new model
+    with te.fp8_model_init(enabled=load_fp8_model):
+        model = te.Linear(
+            hidden_dim,
+            hidden_dim,
+            bias=False,
+            params_dtype=dtype,
+            device=device,
+        )
+
+    # Make sure new model does not match saved model
+    tols = dict(rtol=0.125, atol=0.0675)  # fp8e4m3 epsilon = 0.0625
+    with pytest.raises(AssertionError):
+        torch.testing.assert_close(model.weight, weight_ref, **tols)
+    with te.fp8_autocast():
+        model.init_fp8_metadata()
+        fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
+        fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
+        fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
+        fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"])
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"])
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"])
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"])
+    with te.fp8_autocast():
+        y = model(x.detach().clone())
+    with pytest.raises(AssertionError):
+        torch.testing.assert_close(y, y_ref, **tols)
+
+    # Load checkpoint
+    model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
+    del model_bytes
+
+    # Check that loaded model matches saved model
+    torch.testing.assert_close(model.weight, weight_ref, **tols)
+    with te.fp8_autocast():
+        fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
+        fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
+        fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
+        fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
+        torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"])
+        torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"])
+        torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"])
+        torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"])
+    with te.fp8_autocast():
+        y = model(x.detach().clone())
+    torch.testing.assert_close(y, y_ref, **tols)
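The checkpointing test above corresponds to a workflow along these lines. This is a hedged sketch, not part of the diff; layer sizes and devices are illustrative:

    import io
    import torch
    import transformer_engine.pytorch as te

    # Build a Linear layer whose parameters are stored directly in FP8
    with te.fp8_model_init(enabled=True):
        model = te.Linear(32, 32, bias=False, device="cuda")

    # Float8Tensor parameters now serialize through an ordinary state dict
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)

    # A freshly constructed model (FP8 or high-precision params) can load it
    with te.fp8_model_init(enabled=False):
        model2 = te.Linear(32, 32, bias=False, device="cuda")
    model2.load_state_dict(torch.load(io.BytesIO(buf.getvalue())))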
@@ -435,30 +435,12 @@ class Float8Tensor(torch.Tensor):
             return _IdentityFunc.apply(self)
         return super().expand_as(other)

-    def _transpose_no_cache(self) -> torch.Tensor:
-        """
-        Swap tensor dimensions
-
-        For basic 2D matrix transposes, an optimized transpose kernel
-        is applied and a Float8Tensor is returned.
-
-        """
-
-        # Use optimized kernel for basic 2D transpose
-        # TODO Support differentiation # pylint: disable=fixme
-        return Float8Tensor.make_like(
-            self,
-            data=tex.fp8_transpose(
-                self._data.contiguous().detach(),
-                self._fp8_dtype,
-            ),
-        )
-
     def transpose(
         self,
         dim0: int = 0,
         dim1: int = 1,
         *,
-        update_cache: Optional[bool] = None,
+        update_cache: bool = False,
     ) -> torch.Tensor:
         """
         Swap tensor dimensions
@@ -472,12 +454,14 @@ class Float8Tensor(torch.Tensor):
             The first dimension to be transposed
         dim1: int, default = 1
             The second dimension to be transposed
-        update_cache: Optional[bool], default = None
-            If set to `True`, the result is computed and stored in a cache.
-            If set to `False`, the result is computed only if the cache is
-            empty, otherwise the cache is returned. If set to `None`, the
-            result is not cached. Caching is only supported for basic 2D
-            transposes and the cache is reset after any in-place operations.
+        update_cache: bool, default = False
+            If `True`, the transpose is computed and stored
+            in a cache. If `False`, a cached version is
+            returned if available and otherwise the
+            transpose is computed. Caching is only supported
+            for basic 2D transposes and the cache is reset
+            after any in-place operations.

         """

         # Handle non-2D transposes
@@ -486,22 +470,32 @@ class Float8Tensor(torch.Tensor):
         if -self.dim() <= dim1 < 0:
             dim1 += self.dim()
         if self.dim() != 2 or dim0 == dim1:
-            if update_cache is not None:
+            if update_cache:
                 raise ValueError(
                     "Transpose caching is only supported for basic 2D transposes "
                     f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
                 )
             return super().transpose(dim0, dim1)

-        # No caching.
-        if update_cache is None:
-            return self._transpose_no_cache()
-
-        # Update cache.
-        if update_cache or self._transpose is None:
-            self._transpose = self._transpose_no_cache()
-
-        return self._transpose
+        # Clear cache if needed
+        if update_cache:
+            self._transpose = None
+
+        # Compute transpose if needed
+        out = self._transpose
+        if out is None:
+            out = Float8Tensor.make_like(
+                self,
+                data=tex.fp8_transpose(
+                    self._data.contiguous(),
+                    self._fp8_dtype,
+                ),
+            )
+
+        # Update cache if needed
+        if update_cache:
+            self._transpose = out
+
+        return out

     @torch.no_grad()
     def reset_fp8_meta_scale_inv(self) -> None:
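For reference, the reworked caching behaves roughly as in this sketch (illustrative only; the import path and shapes are assumptions):

    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor  # path assumed

    x_fp8 = Float8Tensor.to_float8(torch.rand(16, 32, device="cuda"))
    t0 = x_fp8.transpose()                    # computed, result not cached
    t1 = x_fp8.transpose(update_cache=True)   # computed and stored in the cache
    t2 = x_fp8.transpose()                    # cache hit: returns the stored transpose
    x_fp8.copy_(torch.rand(16, 32, device="cuda"))  # in-place op resets the cache
    # x_fp8.transpose(0, 0, update_cache=True)      # would raise ValueError: caching is 2D-only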
@@ -550,16 +544,35 @@ class Float8Tensor(torch.Tensor):
             # Check tensors
             dst = args[0]
             src = args[1]
-            if not isinstance(dst, Float8Tensor):
-                raise RuntimeError("Expected to copy into Float8Tensor")
-            if not isinstance(src, torch.Tensor):
-                raise RuntimeError("Expected to copy from tensor")
-            if not dst._data.is_contiguous():
-                raise RuntimeError("Transformer Engine cast kernels require contiguous data")
-
-            # Make sure input is in expected format
-            if isinstance(src, Float8Tensor):
-                src = src.from_float8()
-            src = src.expand(dst.size())
-            src = src.to(
-                device=dst.device,
+            if not isinstance(dst, torch.Tensor):
+                raise RuntimeError(
+                    "Attempted to copy into something that isn't a PyTorch tensor"
+                )
+            if not isinstance(src, torch.Tensor):
+                raise RuntimeError(
+                    "Attempted to copy from something that isn't a PyTorch tensor"
+                )
+
+            # Special handling based on which tensors are FP8
+            dst_is_fp8 = isinstance(dst, Float8Tensor)
+            src_is_fp8 = isinstance(src, Float8Tensor)
+            if dst_is_fp8 and src_is_fp8:
+
+                # Directly copy FP8 data if possible
+                if dst._fp8_dtype == src._fp8_dtype:
+                    dst._data.copy_(src._data)
+                    dst._scale_inv = src._scale_inv.clone()
+                else:
+                    dst.copy_(src.from_float8())
+
+            elif not dst_is_fp8 and src_is_fp8:
+
+                # Cast source tensor to higher precision
+                dst.copy_(src.from_float8())
+
+            elif dst_is_fp8 and not src_is_fp8:
+
+                # Make sure input is in expected format
+                src = src.expand(dst.size())
+                src = src.to(
+                    device=dst.device,
@@ -577,6 +590,8 @@ class Float8Tensor(torch.Tensor):
                 dst._scale_inv = scale.detach().view(1).reciprocal()

                 # Cast to FP8
+                if not dst._data.is_contiguous():
+                    raise RuntimeError("Transformer Engine cast kernels require contiguous data")
                 tex.cast_to_fp8_noalloc(
                     src.view(1,-1),
                     scale,
@@ -586,7 +601,13 @@ class Float8Tensor(torch.Tensor):
                     dst._fp8_dtype,
                 )
+            else:
+
+                # Invalid case
+                raise RuntimeError("Using Float8Tensor copy logic, but no Float8Tensor found")

             # Nothing to return for in-place ops
+            if dst_is_fp8:
                 dst._reset_caches()
             return None
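Summarizing the new copy_ dispatch as a sketch (shapes and the import path are illustrative assumptions):

    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor  # path assumed

    a_fp8 = Float8Tensor.to_float8(torch.rand(8, 8, device="cuda"))
    b_fp8 = Float8Tensor.to_float8(torch.rand(8, 8, device="cuda"))
    c = torch.rand(8, 8, device="cuda")

    a_fp8.copy_(b_fp8)  # FP8 -> FP8 with matching FP8 dtypes: raw bytes and scale_inv copied, no cast
    c.copy_(a_fp8)      # FP8 -> high precision: source is dequantized via from_float8()
    a_fp8.copy_(c)      # high precision -> FP8: source is cast with the FP8 cast kernel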
@@ -658,6 +679,34 @@ class Float8Tensor(torch.Tensor):
         out = super().__torch_dispatch__(func, types, args, kwargs)
         return out

+    @classmethod
+    def _make_in_reduce_ex(
+        cls,
+        data: torch.Tensor,
+        fp8_dtype: tex.DType,
+        fp8_scale_inv: torch.Tensor,
+        dtype: torch.dtype,
+    ) -> Float8Tensor:
+        """Build Float8Tensor, for use in __reduce__
+
+        __reduce_ex__ assumes object constructor has positional
+        arguments.
+
+        """
+        return Float8Tensor(
+            data=data,
+            fp8_dtype=fp8_dtype,
+            fp8_scale_inv=fp8_scale_inv,
+            dtype=dtype,
+        )
+
+    def __reduce_ex__(self, protocol: int) -> tuple:
+        """Custom pickling to remove references to FP8 metadata objects"""
+        return (
+            Float8Tensor._make_in_reduce_ex,
+            (self._data, self._fp8_dtype, self._scale_inv, self.dtype),
+        )
+
     def _get_data(self) -> Float8Tensor:
         """Get tensor data property"""
         return super().data
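To illustrate what the custom __reduce_ex__ hands to the pickle machinery, a small sketch (illustrative; import path assumed):

    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor  # path assumed

    x_fp8 = Float8Tensor.to_float8(torch.rand(4, 4, device="cuda"))
    func, args = x_fp8.__reduce_ex__(2)
    # func is Float8Tensor._make_in_reduce_ex and args is
    # (x_fp8._data, x_fp8._fp8_dtype, x_fp8._scale_inv, x_fp8.dtype),
    # so unpickling rebuilds the tensor without holding a reference to
    # any FP8 metadata object.
    x_rebuilt = func(*args)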
...
@@ -819,19 +819,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         )
         return fp8_weight_tensors

-    def state_dict(self, *args, **kwargs) -> Dict:
-        """Get dictionary containing module state"""
-        state = super().state_dict(*args, **kwargs)
-
-        # Convert Float8Tensors to plain tensors
-        # Note: Float8Tensors don't serialize well, especially if they
-        # contain references to FP8 metadata.
-        for key, val in state.items():
-            if isinstance(val, Float8Tensor):
-                state[key] = val.from_float8()
-
-        return state
-
     @abstractmethod
     def forward(self):
         """Needs override."""
...