Unverified Commit 2d57db8b authored by Tim Moon, committed by GitHub

[PyTorch] Proxy class for low-precision tensor (#1127)



* Add base class for tensor proxies
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move tensor detaching logic to tensor proxy base class
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use Python wrappers to PyTorch extensions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Include transpose caching logic in proxy encode function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug dimension mismatch with amax history
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move dequantize logic to proxy_decode func
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename to "QuantizedTensor"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename "proxy_detach" to "detach"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Include transpose cache in detach and clone funcs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update FP8 workspaces with QuantizedTensor functions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move logic for FP8 transpose cache in FP8 workspaces to base class
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove cast-transpose logic from linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unnecessary args for Float8Tensor when using FP8 attr dict
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move __torch_function__ to QuantizedTensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update tests/pytorch/test_float8tensor.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Debug FP8 transpose test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug cast functions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 40dda924
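The net effect of this PR, as reflected in the hunks below, is that Float8Tensor becomes a subclass of a generic QuantizedTensor and exposes quantize_/dequantize methods, while the transpose cache is allocated up front via to_float8(..., with_transpose_cache=True) instead of being filled by the caller. A minimal usage sketch under those assumptions (tensor shapes and names here are illustrative, not taken from the PR):

import torch
from transformer_engine.pytorch.tensor import Float8Tensor

# Hedged sketch based only on the APIs visible in this diff.
x = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")

# Quantize to FP8 and pre-allocate the transpose cache so later
# cast-transpose kernels can fill it in place.
x_fp8 = Float8Tensor.to_float8(x, with_transpose_cache=True)

# detach()/clone() keep the shared FP8 attribute dict (scale_inv,
# transpose cache, ...) consistent with the original tensor.
x_detached = x_fp8.detach()

# Dequantize back to the nominal dtype for operations without FP8 support.
y = x_fp8.dequantize()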
@@ -31,7 +31,7 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install pip -y
pip install torch
pip install torch numpy
export PYTHON_ONLY=1
export TE_PATH=.
bash ./qa/L0_pytorch_lint/test.sh
@@ -293,7 +293,7 @@ class TestFloat8Tensor:
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8_t, x, **tols)
# Caching test.
# Caching test
assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching."
x_fp8 += 0.5
x = x_fp8.from_float8()
@@ -302,14 +302,13 @@ class TestFloat8Tensor:
torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
# Inplace update test.
# Inplace update test
x_fp8 += 0.5
assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly."
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
x = x_fp8.from_float8()
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True))
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose)
x_t = x.transpose(0, 1)
torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
def test_serialization(
self,
@@ -88,10 +88,7 @@ def make_reference_and_test_tensors(
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
test = Float8Tensor.to_float8(test)
test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1)
test._transpose = test._transpose.contiguous()
test._transpose_invalid = False
test = Float8Tensor.to_float8(test, with_transpose_cache=True)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
ref.copy_(test)
@@ -68,13 +68,13 @@ def canonicalize_fp8_scales(
# Force offsets to be the same if needed
if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset:
if scale_offset != 0:
scale = scale[scale_offset]
scale = scale[scale_offset:]
scale_offset = 0
if amax_offset != 0:
amax = amax[0][amax_offset]
amax = amax[:, amax_offset:]
amax_offset = 0
if scale_inv_offset != 0:
scale_inv = scale_inv[scale_inv_offset]
scale_inv = scale_inv[scale_inv_offset:]
scale_inv_offset = 0
# Pack tensors and offsets into dicts
@@ -8,7 +8,7 @@ from typing import Optional, Union
import torch
import transformer_engine_torch as tex
from ._common import canonicalize_fp8_scales, empty_tensor
from ._common import canonicalize_fp8_scales
__all__ = ["cast_to_fp8", "cast_from_fp8"]
@@ -81,8 +81,7 @@ def cast_from_fp8(
# Construct empty tensors if needed
if scale_inv is None:
scale_inv = empty_tensor()
scale_inv_offset = 0
raise ValueError("Did not provide either `scale_inv` or `fp8_meta_tensor`")
# Launch FP8 cast kernel
return torch.ops.tex_ts.cast_from_fp8_ts(
@@ -3,1004 +3,7 @@
# See LICENSE for license information.
"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple, Union
import warnings
import torch
from torch.utils._pytree import tree_map
import transformer_engine_torch as tex
from .tensor import Float8Tensor
from .constants import TE_DType
from .cpp_extensions import fp8_cast_transpose_fused
from .fp8 import FP8GlobalStateManager
aten = torch.ops.aten
c10d = torch.ops.c10d
updated_fp8_params = {}
def _make_fp8_attr_property_funcs(name: str) -> Any:
"""Make accessors for an FP8 attribute
We store FP8 attributes in a dictionary so we can share them
between tensors with the same data, e.g. detached tensors. For
convenience, we also expose them as property attributes. This
function creates the accessors for property attributes.
Parameters
----------
name: str
Key in dictionary of FP8 attributes
"""
def get_func(self) -> Any:
return self._fp8_attrs[name]
def set_func(self, value: Any) -> None:
self._fp8_attrs[name] = value
def del_func(self) -> None:
del self._fp8_attrs[name]
return dict(fget=get_func, fset=set_func, fdel=del_func)
class _FromFloat8Func(torch.autograd.Function):
"""Cast from FP8 to other dtype"""
@staticmethod
def forward(
_ctx: torch.autograd.function.FunctionCtx, # unused
tensor: Float8Tensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if dtype is None:
dtype = tensor.dtype
data = tensor._data.contiguous().view(1, -1).detach()
out = tex.cast_from_fp8(
data,
tensor._scale_inv,
tensor._fp8_dtype,
TE_DType[dtype],
)
out = out.view(tensor.size())
return out
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# Assume that we want gradients in full precision
return grad, None
def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None:
"""Amax scale and update when there is at least 1 trainable FP8 parameter."""
param_id = id(param._data)
if param_id not in FP8GlobalStateManager.fp8_param_to_autocast:
return
autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id]
if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params:
return
if autocast_key in updated_fp8_params:
updated_fp8_params[autocast_key].add(param_id)
else:
updated_fp8_params[autocast_key] = {param_id}
current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key]
# All FP8 trainable parameters have been updated.
if updated_fp8_params[autocast_key] == current_fp8_params_set:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True)
del updated_fp8_params[autocast_key]
class _ToFloat8Func(torch.autograd.Function):
"""Cast to FP8 from other dtype"""
@staticmethod
def forward(
_ctx: torch.autograd.function.FunctionCtx, # unused
tensor: torch.Tensor,
fp8_meta: Optional[Dict[str, Any]] = None,
fp8_meta_forward: bool = True,
fp8_meta_index: Optional[int] = None,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
) -> Float8Tensor:
# Extract data from FP8 meta tensors if provided
if fp8_meta is not None:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=fp8_meta_forward,
)
if fp8_meta_index is None:
raise ValueError(
"To initialize Float8Tensor with FP8 meta tensors, "
"the FP8 meta tensor index must also be provided"
)
if scale is None:
scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index]
if amax is None:
amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
# Check input tensor
tensor = tensor.contiguous().cuda().detach()
if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16):
tensor = tensor.float()
# Check scale
if not isinstance(scale, torch.Tensor):
if scale is None:
scale = 1
scale = torch.full(
[1],
scale,
dtype=torch.float32,
device=tensor.device,
)
if scale.numel() != 1:
raise ValueError("Attempted to initialize Float8Tensor with invalid scale tensor")
scale = scale.to(device=tensor.device, dtype=torch.float32)
# Check scale-inverse
if scale_inv is None:
scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device)
else:
scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32)
# Check amax
if amax is None:
amax = torch.empty_like(scale)
if not (amax.numel() == 1 and amax.is_cuda and amax.dtype == torch.float32):
raise ValueError("Attempted to initialize Float8Tensor with invalid amax tensor")
# Cast data to FP8
data = tex.cast_to_fp8(
tensor.view(1, -1),
scale,
amax,
scale_inv,
fp8_dtype,
)
data = data.view(tensor.size())
# Construct FP8 tensor
return Float8Tensor(
data=data,
fp8_meta=fp8_meta,
fp8_meta_forward=fp8_meta_forward,
fp8_meta_index=fp8_meta_index,
fp8_dtype=fp8_dtype,
fp8_scale_inv=scale_inv,
dtype=tensor.dtype,
)
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# Assume that we want gradients in full precision
return grad, None, None, None, None, None, None, None
class _IdentityFunc(torch.autograd.Function):
"""Identity function
If constructor keyword-arguments are provided, then construct a
new Float8Tensor using the provided tensor's attributes.
"""
@staticmethod
def forward(
ctx,
tensor: Float8Tensor,
init_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# Return input tensor if constructor kwargs are not provided
ctx.input_dtype = tensor.dtype
if init_kwargs is None:
return tensor
# Construct new tensor if constructor kwargs are provided
default_kwargs = dict(
data=tensor._data,
fp8_meta=tensor._fp8_meta,
fp8_meta_forward=tensor._fp8_meta_forward,
fp8_meta_index=tensor._fp8_meta_index,
fp8_dtype=tensor._fp8_dtype,
fp8_scale_inv=tensor._scale_inv,
dtype=tensor.dtype,
)
for key, val in default_kwargs.items():
if key not in init_kwargs:
init_kwargs[key] = val
return Float8Tensor(**init_kwargs)
@staticmethod
def backward(ctx, grad):
return grad.to(ctx.input_dtype), None
class _ViewFunc(torch.autograd.Function):
"""View function
View the Float8Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: torch.Tensor,
shape: Tuple[int] = None,
) -> torch.Tensor:
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
return tensor
# Construct new tensor if shape is provided
if isinstance(tensor, Float8Tensor):
return Float8Tensor.make_like(
tensor,
data=tensor._data.view(*shape),
)
return tensor.view(*shape)
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
grad,
data=grad._data.view(ctx.shape),
)
return dgrad, None
return grad.view(ctx.shape), None
class _ReshapeFunc(torch.autograd.Function):
"""Reshape function
Reshape the Float8Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: torch.Tensor,
shape: Tuple[int] = None,
) -> torch.Tensor:
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
return tensor
# Construct new tensor if shape is provided
if isinstance(tensor, Float8Tensor):
return Float8Tensor.make_like(
tensor,
data=tensor._data.reshape(*shape),
)
return tensor.reshape(*shape)
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
grad,
data=grad._data.reshape(ctx.shape),
)
return dgrad, None
return grad.reshape(ctx.shape), None
class Float8Tensor(torch.Tensor):
"""Experimental tensor class with FP8 data
The tensor presents as having a standard, higher-precision dtype,
but the data itself is (scaled) FP8. For most tensor operations,
the data will be cast to the nominal dtype before performing the
operation.
Parameters
----------
data: torch.Tensor
Raw FP8 data in a uint8 tensor
fp8_attrs: dict, optional
FP8 metadata, primarily managed by Float8Tensor. If
provided, all other FP8 configuration is ignored.
fp8_meta: dict, optional
FP8 metadata object, primarily managed by TE modules.
fp8_meta_forward: bool, default = `True`
Whether to access the FP8 metadata for the
forward pass. Ignored if fp8_meta is not
provided.
fp8_meta_index: int, optional
Index to access in FP8 meta tensors. Required if
fp8_meta is provided and otherwise ignored.
fp8_dtype: transformer_engine_torch.DType, default = tex.DType.kFloat8E4M3
FP8 format.
fp8_scale_inv: torch.Tensor
Reciprocal of the scaling factor applied when
casting to FP8, i.e. the scaling factor that must
be applied when casting from FP8 to higher
precision. Can be inferred from fp8_meta if
provided.
dtype: torch.dtype, default = torch.float32
Nominal tensor datatype.
"""
def __new__(
cls,
*,
data: torch.Tensor,
fp8_attrs: Optional[Dict[str, Any]] = None,
fp8_meta: Optional[Dict[str, Any]] = None,
fp8_meta_forward: bool = True,
fp8_meta_index: Optional[int] = None,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
fp8_scale_inv: Optional[torch.Tensor] = None,
dtype: torch.dtype = torch.float32,
):
# Check that data buffer is valid
if data.element_size() != 1:
raise ValueError(
f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})"
)
if data.requires_grad:
raise ValueError("Float8Tensor requires non-differentiable data buffer")
if not data.is_cuda:
data = data.cuda()
# Initialize tensor object
self = torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
strides=data.stride(),
storage_offset=data.storage_offset(),
dtype=dtype,
layout=data.layout,
requires_grad=data.requires_grad,
device=data.device,
)
self._data: torch.Tensor = data
# Initialize dict of class attributes
# Note: We store FP8 attributes in a dictionary so we can
# share them between tensors with the same data, e.g. detached
# tensors.
self._fp8_attrs: dict = {}
if fp8_attrs is not None:
self._fp8_attrs = fp8_attrs
return self
# FP8 meta tensors
if fp8_meta is not None and fp8_meta_index is None:
raise ValueError(
"To initialize Float8Tensor with FP8 meta tensors, "
"the FP8 meta tensor index must also be provided"
)
self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta
self._fp8_meta_forward: bool = fp8_meta_forward
self._fp8_meta_index: Optional[int] = fp8_meta_index
# FP8 dtype
assert fp8_dtype in (
tex.DType.kFloat8E4M3,
tex.DType.kFloat8E5M2,
), f"Unsupported fp8_dtype {fp8_dtype}."
self._fp8_dtype: tex.DType = fp8_dtype
# Transposed version of `_data`.
self._transpose: Optional[Float8Tensor] = None
self._transpose_invalid: bool = True
# FP8 scale-inverse
if fp8_scale_inv is None and self._fp8_meta is not None:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
fp8_scale_inv = fp8_scale_inv.detach().view(1).clone()
if fp8_scale_inv is None:
raise ValueError(
"Attempted to initialize Float8Tensor without specifying scale-inverse"
)
if not isinstance(fp8_scale_inv, torch.Tensor):
fp8_scale_inv = torch.full(
[1],
fp8_scale_inv,
dtype=torch.float32,
device=self._data.device,
)
if fp8_scale_inv.numel() != 1:
raise ValueError(
"Attempted to initialize Float8Tensor with invalid scale-inverse tensor"
)
if fp8_scale_inv.dim() != 1:
fp8_scale_inv = fp8_scale_inv.reshape(1)
if fp8_scale_inv.device != self._data.device or fp8_scale_inv.dtype != torch.float32:
fp8_scale_inv = fp8_scale_inv.to(
device=self._data.device,
dtype=torch.float32,
)
self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
return self
@classmethod
def make_like(
cls,
tensor: Float8Tensor,
*,
data: torch.Tensor,
fp8_attrs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Float8Tensor:
"""Use attributes of a Float8Tensor to create another Float8Tensor
See constructor for list of keyword arguments.
"""
default_kwargs = dict(
fp8_meta=tensor._fp8_meta,
fp8_meta_forward=tensor._fp8_meta_forward,
fp8_meta_index=tensor._fp8_meta_index,
fp8_dtype=tensor._fp8_dtype,
fp8_scale_inv=tensor._scale_inv,
dtype=tensor.dtype,
)
for key, val in default_kwargs.items():
if key not in kwargs:
kwargs[key] = val
return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs)
def __repr__(self):
return (
"Float8Tensor("
f"fp8_dtype={self._fp8_dtype}, "
f"scale_inv={self._scale_inv.item()}, "
f"data={self.from_float8(dtype=self.dtype)}"
")"
)
def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Construct plain PyTorch tensor from Float8Tensor
By default the resulting tensor's dtype is the
Float8Tensor's nominal dtype.
"""
return _FromFloat8Func.apply(self, dtype)
@classmethod
def to_float8(
cls,
tensor: torch.Tensor,
*,
fp8_meta: Optional[Dict[str, Any]] = None,
fp8_meta_forward: bool = True,
fp8_meta_index: Optional[int] = None,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
):
"""Construct Float8Tensor from plain PyTorch tensor"""
return _ToFloat8Func.apply(
tensor,
fp8_meta,
fp8_meta_forward,
fp8_meta_index,
fp8_dtype,
scale,
amax,
scale_inv,
)
def float(self) -> torch.Tensor:
return self.from_float8(dtype=torch.float32)
def bfloat16(self) -> torch.Tensor:
return self.from_float8(dtype=torch.bfloat16)
def half(self) -> torch.Tensor:
return self.from_float8(dtype=torch.float16)
def cpu(self) -> torch.Tensor:
return self.from_float8().cpu()
def clone(self) -> Float8Tensor:
return _IdentityFunc.apply(self, {"data": self._data.detach().clone()})
def view(self, *shape: Tuple[int]) -> Float8Tensor:
return _ViewFunc.apply(self, shape)
def reshape(self, *shape: Tuple[int]) -> Float8Tensor:
return _ReshapeFunc.apply(self, shape)
def expand_as(self, other: torch.Tensor):
if other is self:
# Note: expand_as is hackily used to create dummy autograd nodes
# and access the backward graph (see
# https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026).
# We equally hackily add a dummy function to handle this
# case.
return _IdentityFunc.apply(self)
return super().expand_as(other)
def contiguous(
self,
*,
memory_format: torch.memory_format = torch.contiguous_format,
) -> Float8Tensor:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
if self._data.is_contiguous(memory_format=memory_format):
return self
return _IdentityFunc.apply(
self,
{"data": self._data.detach().contiguous(memory_format=memory_format)},
)
def transpose_2d(
self,
*,
force_compute: bool = False,
fill_cache: bool = False,
noop_flag: Optional[torch.Tensor] = None,
cache: Optional[bool] = None,
) -> torch.Tensor:
"""
2D transpose with caching support.
Parameters
----------
force_compute: bool, default = `False`
Force computation of transpose. Otherwise use
cached values, if possible.
fill_cache: bool, default = `False`
Cache output tensor for future function calls.
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid updating
cached values, if possible.
cache: bool, deprecated
"""
assert self.dim() == 2, f"{self.dim()}-D transpose not supported."
# Handle deprecated cache kwarg
if cache is not None:
msg = (
"cache kwarg for Float8Tensor.transpose_2d is deprecated, "
"please use force_compute and fill_cache instead"
)
warnings.warn(msg, DeprecationWarning)
if cache:
force_compute = False
fill_cache = True
else:
force_compute = True
fill_cache = False
# Need to compute transpose if cache is invalid
need_compute = force_compute
if self._transpose is None:
need_compute = True
elif self._transpose_invalid:
need_compute = True
# Need to apply transpose kernel if noop flag is applied
if noop_flag is not None:
need_compute = True
# Return cached transpose if possible
if not need_compute:
return self._transpose
# Allocate output if needed
data = self._data.contiguous().reshape(-1, self.size(-1))
out = self._transpose
if out is None:
out = torch.empty(
(data.size(1), data.size(0)),
dtype=torch.uint8,
device=data.device,
)
noop_flag = None
else:
self._transpose_invalid = False
# Apply transpose kernel
fp8_dtype = self._fp8_dtype
if noop_flag is None:
tex.fp8_transpose_noalloc(data, out, fp8_dtype)
else:
noop_flag = noop_flag.to(dtype=torch.float32, device=data.device)
tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype)
# Fill cache if needed
if fill_cache:
self._transpose = out
self._transpose_invalid = False
return out
@torch.no_grad()
def cast_transpose_(
self,
tensor: torch.Tensor,
noop_flag: Optional[torch.Tensor] = None,
) -> None:
"""Cast from tensor and populate transpose cache
Only supported for 2D tensors.
Parameters
----------
tensor: torch.Tensor
Tensor to copy from. Must have same dimensions as
destination tensor.
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid updating
destination tensor.
"""
# Make sure tensor is in expected format
data = self._data
if (
tensor.device != data.device
or tensor.dtype not in (torch.float32, torch.float16, torch.bfloat16)
or not tensor.is_contiguous()
):
dtype = tensor.dtype
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
dtype = torch.float32
tensor = tensor.to(
device=self.device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
if tensor.size() != data.size() or data.dim() != 2:
raise ValueError(
"Invalid tensor dimensions for FP8 cast-transpose "
f"(src={tuple(tensor.size())}, dst={tuple(data.size())})"
)
if not data.is_contiguous():
raise ValueError(
"FP8 cast-transpose is only supported for `Float8Tensor`s with contiguous data"
)
if self._fp8_meta is None:
raise ValueError(
"FP8 cast-transpose is only supported for `Float8Tensor`s with FP8 metadata "
)
# Construct transpose cache if needed
transpose = self._transpose
if transpose is None or not transpose.is_contiguous():
transpose = torch.empty(
(data.size(1), data.size(0)),
dtype=torch.uint8,
device=data.device,
)
self._transpose = transpose
noop_flag = None
# Launch cast-transpose kernel
fp8_meta_index = int(self._fp8_meta_index)
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
fp8_meta = self._fp8_meta[fp8_meta_key]
fp8_cast_transpose_fused(
tensor,
fp8_meta,
fp8_meta_index,
self._fp8_dtype,
cast_out=data,
transpose_out=transpose,
scale_inv=self._scale_inv,
noop_flag=noop_flag,
)
self._transpose_invalid = False
@torch.no_grad()
def reset_fp8_meta_scale_inv(self) -> None:
"""Replace FP8 meta tensor scale-inverse with cached value
The FP8 meta tensor scale_inv entry corresponding to this
tensor is replaced with the scale_inv value used to construct
the tensor.
"""
assert self._fp8_meta is not None, "FP8 meta tensors not found."
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0])
def to_dtype(self, dtype: torch.dtype) -> Float8Tensor:
"""Create `Float8Tensor` with given nominal dtype
The new tensor has the same underlying FP8 data.
"""
return Float8Tensor.make_like(
self,
data=self._data,
fp8_attrs=self._fp8_attrs,
dtype=dtype,
)
def _reset_caches(self) -> None:
"""
Set transpose cache as invalid.
Should be called after any in-place operation.
"""
self._transpose_invalid = True
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# In-place copy op
if func == aten.copy_.default:
# Check tensors
dst = args[0]
src = args[1]
if not isinstance(dst, torch.Tensor):
raise RuntimeError("Attempted to copy into something that isn't a PyTorch tensor")
if not isinstance(src, torch.Tensor):
raise RuntimeError("Attempted to copy from something that isn't a PyTorch tensor")
# Special handling based on which tensors are FP8
dst_is_fp8 = isinstance(dst, Float8Tensor)
src_is_fp8 = isinstance(src, Float8Tensor)
if dst_is_fp8 and src_is_fp8:
# Directly copy FP8 data if possible
if dst._fp8_dtype == src._fp8_dtype:
dst._data.copy_(src._data)
dst._scale_inv.copy_(src._scale_inv.detach())
if dst._fp8_meta is not None:
if src._fp8_meta is None:
src_min, src_max = src.from_float8().aminmax()
src_amax = torch.maximum(-src_min, src_max)
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=src._fp8_meta_forward,
)
fp8_meta_index = src._fp8_meta_index
src_amax = src._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dst._fp8_meta_forward,
)
fp8_meta_index = dst._fp8_meta_index
dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
torch.maximum(src_amax, dst_amax, out=dst_amax)
else:
dst.copy_(src.from_float8())
elif not dst_is_fp8 and src_is_fp8:
# Cast source tensor to higher precision
dst.copy_(src.from_float8())
elif dst_is_fp8 and not src_is_fp8:
# Make sure input is in expected format
src = src.expand(dst.size())
src = src.to(
device=dst.device,
memory_format=torch.contiguous_format,
)
# Update scaling factor if FP8 meta tensors are available
if dst._fp8_meta is None:
scale = dst._scale_inv.reciprocal()
amax = torch.empty_like(scale)
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dst._fp8_meta_forward,
)
fp8_meta_index = dst._fp8_meta_index
scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index]
amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
# Cast to FP8
if not dst._data.is_contiguous():
raise RuntimeError("Transformer Engine cast kernels require contiguous data")
tex.cast_to_fp8_noalloc(
src.view(1, -1),
scale,
dst._data.view(1, -1),
amax,
dst._scale_inv,
dst._fp8_dtype,
)
# This branch is where the FP8 parameters are updated in-place during optimization.
# Handle forward amax reduction.
post_optimizer_step_fwd_amax_reduction(dst)
else:
# Invalid case
raise RuntimeError("Using Float8Tensor copy logic, but no Float8Tensor found")
# Nothing to return for in-place ops
if dst_is_fp8:
dst._reset_caches()
return None
# Slice op
if func == aten.slice.Tensor:
tensor = args[0]
data = tensor._data
data_slice = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=data_slice)
# Detach op
if func == aten.detach.default:
# Simply return a new Float8Tensor with the same attrs
return Float8Tensor.make_like(
args[0],
data=args[0]._data,
fp8_attrs=args[0]._fp8_attrs,
)
# View op
if func == aten.view.default:
tensor = args[0]
data = tensor._data
data_view = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(
tensor,
data=data_view,
fp8_attrs=tensor._fp8_attrs,
)
def maybe_unwrap(t):
if isinstance(t, Float8Tensor):
return t.from_float8()
return t
def maybe_update_inplace(arg, new_arg, schema_arg):
"""Update values of FP8 tensors
Keep the same FP8 scaling factors.
"""
if (
isinstance(arg, Float8Tensor)
and isinstance(new_arg, torch.Tensor)
and hasattr(schema_arg, "alias_info")
and hasattr(schema_arg.alias_info, "is_write")
and schema_arg.alias_info.is_write
):
arg.copy_(new_arg)
arg._reset_caches()
# In-place op
if func._schema.is_mutable:
# Cast to higher precision, perform op, and cast values
# back to original FP8 buffers
new_args = tree_map(maybe_unwrap, args)
new_kwargs = tree_map(maybe_unwrap, kwargs)
schema_args = func._schema.arguments
args_len = len(args)
out = super().__torch_dispatch__(func, types, new_args, new_kwargs)
for arg, new_arg, schema_arg in zip(args, new_args, schema_args):
maybe_update_inplace(arg, new_arg, schema_arg)
for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match"
maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg)
return None
# Default op
# Note: cast to higher precision and perform op
args = tree_map(maybe_unwrap, args)
if kwargs is not None:
kwargs = tree_map(maybe_unwrap, kwargs)
out = super().__torch_dispatch__(func, types, args, kwargs)
return out
@classmethod
def _make_in_reduce_ex(
cls,
data: torch.Tensor,
fp8_dtype: tex.DType,
fp8_scale_inv: torch.Tensor,
dtype: torch.dtype,
) -> Float8Tensor:
"""Build Float8Tensor, for use in __reduce__
__reduce_ex__ assumes object constructor has positional
arguments.
"""
return Float8Tensor(
data=data,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
dtype=dtype,
)
def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling to remove references to FP8 metadata objects"""
return (
Float8Tensor._make_in_reduce_ex,
(self._data, self._fp8_dtype, self._scale_inv, self.dtype),
)
def _get_data(self) -> Float8Tensor:
"""Get tensor data property"""
return super().data
def _set_data(self, tensor: torch.Tensor) -> None:
"""Set tensor data property
Cast tensor to FP8 and store in FP8 buffer.
"""
with torch.no_grad():
self.copy_(tensor)
# Cast to FP8 when setting Float8Tensor.data
data = property(_get_data, _set_data)
# Accessors for objects in self._fp8_attrs
# Note: We store FP8 attributes in a dictionary so we can share
# them between tensors with the same data, e.g. detached tensors.
# For convenience, we also expose them as property attributes.
_fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta"))
_fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward"))
_fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index"))
_fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype"))
_transpose = property(**_make_fp8_attr_property_funcs("transpose"))
_transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid"))
_scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv"))
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
__all__ = ["Float8Tensor"]
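A detail worth highlighting from the module removed above: FP8 metadata lives in a single _fp8_attrs dict exposed through property accessors, so detached and viewed tensors alias the same attributes as the original. The toy sketch below (standalone code, not part of Transformer Engine) illustrates why mutating one tensor's scale-inverse is visible through the other:

from typing import Any

def _make_attr_property(name: str) -> property:
    """Property accessors backed by a shared attribute dict."""
    def fget(self) -> Any:
        return self._attrs[name]
    def fset(self, value: Any) -> None:
        self._attrs[name] = value
    def fdel(self) -> None:
        del self._attrs[name]
    return property(fget, fset, fdel)

class _Tensorish:
    """Toy stand-in for Float8Tensor's shared-attribute scheme."""
    def __init__(self, attrs=None):
        # The dict is shared, not copied, so "detached" copies alias it.
        self._attrs = {} if attrs is None else attrs

    def detach(self) -> "_Tensorish":
        return _Tensorish(attrs=self._attrs)

    scale_inv = _make_attr_property("scale_inv")

a = _Tensorish()
a.scale_inv = 0.5
b = a.detach()
b.scale_inv = 0.25          # updates the shared dict...
assert a.scale_inv == 0.25  # ...so the original sees the change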
@@ -865,11 +865,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# If primary weights are in fp8, wrap the parameter as Float8Tensor
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
if self.primary_weights_in_fp8 and fp8_meta_index is not None:
dummy_amax = torch.empty(
(1, 1),
dtype=torch.float32,
device=param.device,
) # Dummy buffer to avoid overwriting amax history
param = Float8Tensor.to_float8(
param,
fp8_meta=self.fp8_meta,
fp8_meta_index=fp8_meta_index,
amax=torch.empty(1, device="cuda"), # Dummy amax to avoid overwriting history.
amax=dummy_amax,
with_transpose_cache=torch.is_grad_enabled(),
)
# Redo parameter wrap in case we broke it above
@@ -891,7 +897,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
cache_name: Optional[str] = None,
update_workspace: bool = True,
skip_update_flag: Optional[torch.Tensor] = None,
with_transpose: bool = False,
fsdp_group: dist_group_type = None,
) -> Float8Tensor:
"""Get FP8 workspace buffer and maybe update its values
@@ -917,27 +922,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
skip_update_flag: torch.Tensor, optional
GPU flag to skip updating the workspace. Take precedence
over `update_workspace` if provided.
with_transpose: bool, default = `False`
Whether to initialize cached transpose in workspace.
fsdp_group: bool, default = None
FSDP process group that the weights are distributed over.
"""
# Construct workspace if needed
# Try getting workspace from cache
out = None
if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None)
# Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
# for models initialized with Fp8 primary weights.
if (
not isinstance(out, Float8Tensor)
and fsdp_group is not None
and out._data.shape != tensor.data.shape
):
_fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)
# Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
# for models initialized with Fp8 primary weights.
if (
out is not None
and not isinstance(out, Float8Tensor)
and fsdp_group is not None
and out._data.shape != tensor.data.shape
):
_fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)
# Construct workspace if needed
if out is None:
# FP8 data
if tensor is None or fp8_meta_forward is None or fp8_meta_index is None:
raise ValueError(
"tensor, fp8_meta_forward, and fp8_meta_index kwargs "
@@ -947,16 +955,38 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["recipe"],
fprop_tensor=fp8_meta_forward,
)
data = torch.empty_like(tensor, dtype=torch.uint8)
scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device)
# Transpose cache
with_transpose_cache = torch.is_grad_enabled()
if (
not with_transpose_cache
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose_cache = True
data_transpose = None
if with_transpose_cache:
data_transpose = torch.empty(
(tensor.size(-1), tensor.numel() // tensor.size(-1)),
dtype=torch.uint8,
device=tensor.device,
)
# Construct FP8 tensor
out = Float8Tensor(
data=torch.empty_like(tensor, dtype=torch.uint8),
data=data,
fp8_meta=self.fp8_meta,
fp8_meta_forward=fp8_meta_forward,
fp8_meta_index=fp8_meta_index,
fp8_dtype=fp8_dtype,
fp8_scale_inv=scale_inv,
dtype=tensor.dtype,
data_transpose=data_transpose,
)
# Update cache
if cache_name is not None:
self._fp8_workspaces[cache_name] = out
update_workspace = True
@@ -968,33 +998,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if update_workspace:
if tensor is None:
raise ValueError("tensor kwarg must be provided to update FP8 workspace")
if with_transpose:
out.cast_transpose_(
tensor,
noop_flag=skip_update_flag,
)
if is_in_onnx_export_mode():
# ONNX export does not support fused cast-transpose
# kernel and requires that FP8 scales can be
# represented with constant ops.
transpose_cache = out._transpose
out._transpose = None
out.quantize_(tensor)
out._scale_inv.fill_(out._scale_inv.item())
out._transpose = transpose_cache
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=out._fp8_meta_forward,
)
fp8_meta = out._fp8_meta[fp8_meta_key]
fp8_meta_index = out._fp8_meta_index
cast_to_fp8(
tensor,
fp8_meta,
fp8_meta_index,
out._fp8_dtype,
out=out._data,
)
if is_in_onnx_export_mode():
# ONNX export expects FP8 scales can be
# represented with constant ops. However, copying
# into a buffer involves an expand op for array
# broadcasting. We work around this by filling the
# buffer instead.
out._scale_inv.fill_(fp8_meta.scale_inv[fp8_meta_index].item())
else:
out._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index])
out.quantize_(tensor, noop_flag=skip_update_flag)
return out
@@ -28,8 +28,6 @@ from ..utils import (
from ..distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
)
from ..cpp_extensions import (
cast_to_fp8,
@@ -760,22 +758,12 @@ class GroupedLinear(TransformerEngineBaseModule):
weight_tensors_fp8 = [None] * self.num_gemms
if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
for i in range(self.num_gemms):
if isinstance(weight_tensors[i], Float8Tensor):
# Fill transpose cache in FP8 tensor if needed
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch or skip_fp8_weight_update is not None
)
if update_transpose_cache:
# Make sure transpose cache is valid, if present
# Note: Transpose cache may have been invalidated
# externally, e.g. by optimizer.
if weight_tensors[i]._transpose is not None:
weight_tensors[i].transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
@@ -790,7 +778,6 @@ class GroupedLinear(TransformerEngineBaseModule):
cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)
from ..cpu_offload import CPUOffloadEnabled
@@ -36,8 +36,6 @@ from ..distributed import (
allreduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
_fsdp_scatter_tensors,
_fsdp_gather_tensors,
)
@@ -47,6 +45,7 @@ from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor
__all__ = ["LayerNormLinear"]
@@ -1151,14 +1150,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, Float8Tensor) for w in unfused_weights):
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting Float8Tensor into multiple params is not supported"
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
unfused_weights = [w.from_float8() for w in unfused_weights]
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat(
@@ -1170,32 +1169,18 @@ class LayerNormLinear(TransformerEngineBaseModule):
# Initialize FP8 weights if needed
weight_fp8 = None
if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
if isinstance(weight_tensor, Float8Tensor):
# Fill transpose cache in FP8 tensor if needed
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch or skip_fp8_weight_update is not None
)
if update_transpose_cache:
# Make sure transpose cache is valid, if present
# Note: Transpose cache may have been invalidated
# externally, e.g. by optimizer.
if weight_tensor._transpose is not None:
weight_tensor.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
)
else:
# FP8 cast to workspace buffer
update_workspace = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
update_workspace = is_first_microbatch is None or is_first_microbatch
weight_fp8 = self.get_fp8_workspace(
tensor=weight_tensor,
fp8_meta_forward=True,
@@ -1203,7 +1188,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
cache_name=(None if is_first_microbatch is None else "weight"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)
from ..cpu_offload import CPUOffloadEnabled
@@ -42,8 +42,6 @@ from ..distributed import (
allreduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
use_reentrant_activation_recompute,
_fsdp_scatter_tensors,
_fsdp_gather_tensors,
@@ -1485,19 +1483,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_fp8 = None
if self.fp8:
update_workspace = is_first_microbatch is None or is_first_microbatch
with_transpose = torch.is_grad_enabled()
if (
is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch or skip_fp8_weight_update is not None
)
if isinstance(fc1_weight, Float8Tensor):
if update_transpose_cache:
if fc1_weight._transpose is not None:
fc1_weight.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
@@ -1513,10 +1500,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
cache_name=cache_name,
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)
if isinstance(fc2_weight, Float8Tensor):
if update_transpose_cache:
if fc2_weight._transpose is not None:
fc2_weight.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
@@ -1532,7 +1518,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
cache_name=cache_name,
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
@@ -33,8 +33,6 @@ from ..distributed import (
allreduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
_fsdp_scatter_tensors,
_fsdp_gather_tensors,
)
@@ -49,6 +47,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor
__all__ = ["Linear"]
@@ -938,19 +937,19 @@ class Linear(TransformerEngineBaseModule):
with self.prepare_forward(
inp,
is_first_microbatch,
allow_non_contiguous=isinstance(inp, Float8Tensor),
allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, Float8Tensor) for w in unfused_weights):
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting Float8Tensor into multiple params is not supported"
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
unfused_weights = [w.from_float8() for w in unfused_weights]
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat(
@@ -962,21 +961,11 @@ class Linear(TransformerEngineBaseModule):
# Initialize FP8 weights if needed
weight_fp8 = None
if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
if isinstance(weight_tensor, Float8Tensor):
# Fill transpose cache in FP8 tensor if needed
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch or skip_fp8_weight_update is not None
)
if update_transpose_cache:
# Make sure transpose cache is valid, if present
# Note: Transpose cache may have been invalidated
# externally, e.g. by optimizer.
if weight_tensor._transpose is not None:
weight_tensor.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
......@@ -991,7 +980,6 @@ class Linear(TransformerEngineBaseModule):
cache_name=(None if is_first_microbatch is None else "weight"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
fsdp_group=self.fsdp_group,
)
@@ -9,54 +9,12 @@ from typing import Any, Iterable, Optional
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor
def canonicalize_device(device: Optional[torch.device | str]) -> torch.device:
"""Canonicalize PyTorch device
If `None`, then returns the default CUDA device.
"""
if device is None:
# Use default CUDA device
device = torch.get_default_device()
if device.type != "cuda":
device = torch.device("cuda", torch.cuda.current_device())
elif not isinstance(device, torch.device):
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device("cuda", torch.cuda.current_device())
return device
def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
"""Canonicalize PyTorch datatype
If `None`, then returns the default PyTorch datatype.
"""
if dtype is None:
# Use default dtype
dtype = torch.get_default_dtype()
return dtype
def devices_match(device1: torch.device, device2: torch.device) -> bool:
"""Whether two devices are the same"""
device1 = torch.device(device1)
device2 = torch.device(device2)
if device1.type != device2.type:
return False
if device1.type == "cuda":
index1 = device1.index
index2 = device2.index
if index1 is None:
index1 = torch.cuda.current_device()
if index2 is None:
index2 = torch.cuda.current_device()
return index1 == index2
return device1 == device2
from ..tensor import Float8Tensor
from ..utils import (
canonicalize_device, # pylint: disable=unused-import
canonicalize_dtype, # pylint: disable=unused-import
devices_match, # pylint: disable=unused-import
)
def is_float8_tensor(tensor: Any) -> bool:
@@ -92,7 +50,13 @@ def convert_tensor(
# Convert FP8 tensor
if is_float8_tensor(tensor):
data = tensor._data.to(device=device, memory_format=memory_format)
data = tensor._data
if not devices_match(device, data.device):
data = data.to(device=device)
if memory_format != torch.preserve_format and not data.is_contiguous(
memory_format=memory_format
):
data = data.contiguous(memory_format=memory_format)
return Float8Tensor.make_like(
tensor,
data=data,
@@ -9,11 +9,8 @@ from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import is_float8_tensor
from ...tensor import QuantizedTensor
from ..op import BasicOperation, OperationContext
class AllReduce(BasicOperation):
@@ -54,8 +51,8 @@ class AllReduce(BasicOperation):
# Perform all-reduce
x = input_
if is_float8_tensor(x):
x = x.from_float8()
if isinstance(x, QuantizedTensor):
x = x.dequantize()
x = x.contiguous()
torch.distributed.all_reduce(x, group=self.process_group)
return x
@@ -289,10 +289,18 @@ class BasicLinear(BasicOperation):
# Cast to FP8 if needed
if self._with_fp8_parameters:
dummy_amax = torch.empty(
(1, 1),
dtype=torch.float32,
device=self.device,
) # Dummy buffer to avoid overwriting amax history
weight = Float8Tensor.to_float8(
weight,
fp8_meta=self.get_fp8_meta("param"),
fp8_meta_forward=True,
fp8_meta_index=0,
amax=dummy_amax,
with_transpose_cache=torch.is_grad_enabled(),
)
# Save updated parameter
@@ -467,25 +475,19 @@ class BasicLinear(BasicOperation):
input_fp8_meta["recipe"],
fprop_tensor=True,
)
x_fp8 = Float8Tensor(
data=torch.empty_like(x_local, dtype=torch.uint8),
with_transpose_cache = weight.requires_grad
if tensor_parallel_mode == "column" and sequence_parallel:
with_transpose_cache = False
x_local = Float8Tensor.to_float8(
x_local,
fp8_meta=input_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
with_transpose_cache=with_transpose_cache,
)
with_cast_transpose = weight.requires_grad
if tensor_parallel_mode == "column" and sequence_parallel:
with_cast_transpose = False
if with_cast_transpose:
x_fp8.cast_transpose_(x_local)
else:
x_fp8.copy_(x_local)
x_local = x_fp8
elif not with_fp8_compute and is_float8_tensor(x_local):
x_local = x_local.from_float8()
x_local = x_local.dequantize()
x = x_local
x_async = None
if tensor_parallel_mode == "column" and sequence_parallel:
@@ -510,11 +512,12 @@ class BasicLinear(BasicOperation):
w = Float8Tensor.to_float8(
w,
fp8_meta=weight_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
)
elif not with_fp8_compute and is_float8_tensor(w):
w = w.from_float8()
w = w.dequantize()
# Check bias tensor
b = None
@@ -815,25 +818,19 @@ class BasicLinear(BasicOperation):
grad_output_fp8_meta["recipe"],
fprop_tensor=False,
)
dy_fp8 = Float8Tensor(
data=torch.empty_like(dy, dtype=torch.uint8),
with_transpose_cache = weight_requires_grad
if tensor_parallel_mode == "row" and sequence_parallel:
with_transpose_cache = False
dy = Float8Tensor.to_float8(
dy,
fp8_meta=grad_output_fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
with_transpose_cache=with_transpose_cache,
)
with_cast_transpose = weight_requires_grad
if tensor_parallel_mode == "row" and sequence_parallel:
with_cast_transpose = False
if with_cast_transpose:
dy_fp8.cast_transpose_(dy)
else:
dy_fp8.copy_(dy)
dy = dy_fp8
elif not with_fp8_compute and is_float8_tensor(dy):
dy = dy.from_float8()
dy = dy.dequantize()
if tensor_parallel_mode == "row" and sequence_parallel:
dy, dy_async = gather_along_first_dim(
dy,
@@ -853,26 +850,24 @@ class BasicLinear(BasicOperation):
device=device,
dtype=dtype,
)
x_is_sharded = tensor_parallel_mode == "column" and sequence_parallel
if with_fp8_compute and not is_float8_tensor(x_local):
fp8_dtype = get_fp8_te_dtype(
input_fp8_meta["recipe"],
fprop_tensor=True,
)
x_fp8 = Float8Tensor(
data=torch.empty_like(x_local, dtype=torch.uint8),
x_local = Float8Tensor.to_float8(
x_local,
fp8_meta=input_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
with_transpose_cache=(not x_is_sharded),
)
x_fp8.cast_transpose_(x_local)
x_local = x_fp8
elif not with_fp8_compute and is_float8_tensor(x_local):
x_local = x_local.from_float8()
x = x_local
if tensor_parallel_mode == "column" and sequence_parallel:
if x_is_sharded:
x, x_async = gather_along_first_dim(
x_local,
tensor_parallel_group,
@@ -898,19 +893,16 @@ class BasicLinear(BasicOperation):
weight_fp8_meta["recipe"],
fprop_tensor=True,
)
w_fp8 = Float8Tensor(
data=torch.empty_like(w, dtype=torch.uint8),
w = Float8Tensor.to_float8(
w,
fp8_meta=weight_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
with_transpose_cache=True,
)
w_fp8.cast_transpose_(w)
w = w_fp8
elif not with_fp8_compute and is_float8_tensor(w):
w = w.from_float8()
w = w.dequantize()
# Construct grad input tensor
if grad_input is not None:
@@ -9,12 +9,9 @@ from typing import Optional
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import convert_tensor, is_float8_tensor
from ...tensor import Float8Tensor, QuantizedTensor
from ..op import BasicOperation, OperationContext
from .._common import convert_tensor
class ReduceScatter(BasicOperation):
@@ -63,8 +60,8 @@ class ReduceScatter(BasicOperation):
# Check input tensor
x = input_
if is_float8_tensor(x):
x = x.from_float8()
if isinstance(x, QuantizedTensor):
x = x.dequantize()
x = x.contiguous()
# Perform reduce-scatter
@@ -96,7 +93,7 @@ class ReduceScatter(BasicOperation):
# Perform all-gather
dy = convert_tensor(grad_output, memory_format=torch.contiguous_format)
dx = None
if is_float8_tensor(dy):
if isinstance(dy, Float8Tensor):
dx = Float8Tensor.make_like(
dy,
data=torch.empty(
@@ -111,6 +108,8 @@
group=self.process_group,
)
else:
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device)
torch.distributed.all_gather_into_tensor(
dx,
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Custom tensor classes"""
from .float8_tensor import Float8Tensor
from .quantized_tensor import QuantizedTensor
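The new transformer_engine.pytorch.tensor package re-exports both classes, which is what lets the module code above and below dispatch on the generic base class rather than on Float8Tensor directly. A short hedged sketch of that pattern (the weight tensor here is illustrative only):

import torch
from transformer_engine.pytorch.tensor import Float8Tensor, QuantizedTensor

w = Float8Tensor.to_float8(torch.randn(64, 64, device="cuda"))

# Modules now branch on the generic QuantizedTensor base class and
# dequantize when full-precision math is required.
if isinstance(w, QuantizedTensor):
    w_hp = w.dequantize()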
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple
import warnings
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..constants import TE_DType as torch_to_transformer_engine_dtype
from ..cpp_extensions import (
cast_from_fp8,
cast_to_fp8,
fp8_cast_transpose_fused,
)
from ..fp8 import FP8GlobalStateManager
from ..utils import devices_match
from .quantized_tensor import QuantizedTensor
aten = torch.ops.aten
updated_fp8_params = {}
def _make_fp8_attr_property_funcs(name: str) -> Any:
"""Make accessors for an FP8 attribute
We store FP8 attributes in a dictionary so we can share them
between tensors with the same data, e.g. detached tensors. For
convenience, we also expose them as property attributes. This
function creates the accessors for property attributes.
Parameters
----------
name: str
Key in dictionary of FP8 attributes
"""
def get_func(self) -> Any:
return self._fp8_attrs[name]
def set_func(self, value: Any) -> None:
self._fp8_attrs[name] = value
def del_func(self) -> None:
del self._fp8_attrs[name]
return dict(fget=get_func, fset=set_func, fdel=del_func)
class _FromFloat8Func(torch.autograd.Function):
"""Cast from FP8 to other dtype"""
@staticmethod
def forward(
_ctx: torch.autograd.function.FunctionCtx, # unused
tensor: Float8Tensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
return tensor.dequantize(dtype=dtype)
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# Assume that we want gradients in full precision
return grad, None
def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None:
"""Amax scale and update when there is at least 1 trainable FP8 parameter."""
param_id = id(param._data)
if param_id not in FP8GlobalStateManager.fp8_param_to_autocast:
return
autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id]
if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params:
return
if autocast_key in updated_fp8_params:
updated_fp8_params[autocast_key].add(param_id)
else:
updated_fp8_params[autocast_key] = {param_id}
current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key]
# All FP8 trainable parameters have been updated.
if updated_fp8_params[autocast_key] == current_fp8_params_set:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True)
del updated_fp8_params[autocast_key]
class _ToFloat8Func(torch.autograd.Function):
"""Cast to FP8 from other dtype"""
@staticmethod
def forward(
_ctx: torch.autograd.function.FunctionCtx, # unused
tensor: torch.Tensor,
fp8_meta: Optional[Dict[str, Any]] = None,
fp8_meta_forward: bool = True,
fp8_meta_index: Optional[int] = None,
fp8_dtype: TE_DType = TE_DType.kFloat8E4M3,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
with_transpose_cache: bool = False,
) -> Float8Tensor:
# Tensor attributes
dtype = tensor.dtype
if dtype not in (torch.float32, torch.bfloat16, torch.float16):
dtype = torch.float32
device = tensor.device
if device.type != "cuda":
device = torch.device("cuda")
# FP8 data buffer
data = torch.empty(tensor.size(), dtype=torch.uint8, device=device)
# Check scale
if scale is None and fp8_meta is None:
scale = 1
if scale is not None:
if isinstance(scale, torch.Tensor):
scale = scale.to(device=device, dtype=torch.float32)
else:
scale = torch.full([1], scale, dtype=torch.float32, device=device)
# Check scale-inverse
if scale_inv is None:
scale_inv = torch.empty([1], dtype=torch.float32, device=device)
elif not devices_match(scale_inv.device, device) or scale_inv.dtype != dtype:
scale_inv = scale_inv.to(device=device, dtype=torch.float32)
# Transpose cache
data_transpose = None
if with_transpose_cache:
data_transpose = torch.empty(
(data.size(-1), data.numel() // data.size(-1)),
dtype=torch.uint8,
device=tensor.device,
)
# Construct FP8 tensor
out = Float8Tensor(
data=data,
fp8_meta=fp8_meta,
fp8_meta_forward=fp8_meta_forward,
fp8_meta_index=fp8_meta_index,
fp8_dtype=fp8_dtype,
fp8_scale_inv=scale_inv,
dtype=dtype,
data_transpose=data_transpose,
)
# Cast to FP8 tensor
out.quantize_(tensor, scale=scale, amax=amax)
return out
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# Assume that we want gradients in full precision
return grad, None, None, None, None, None, None, None, None
class _IdentityFunc(torch.autograd.Function):
"""Identity function
If constructor keyword-arguments are provided, then construct a
new Float8Tensor using the provided tensor's attributes.
"""
@staticmethod
def forward(
ctx,
tensor: Float8Tensor,
init_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# Return input tensor if constructor kwargs are not provided
ctx.input_dtype = tensor.dtype
if init_kwargs is None:
return tensor
# Construct new tensor if constructor kwargs are provided
default_kwargs = dict(
data=tensor._data,
fp8_meta=tensor._fp8_meta,
fp8_meta_forward=tensor._fp8_meta_forward,
fp8_meta_index=tensor._fp8_meta_index,
fp8_dtype=tensor._fp8_dtype,
fp8_scale_inv=tensor._scale_inv,
dtype=tensor.dtype,
)
for key, val in default_kwargs.items():
if key not in init_kwargs:
init_kwargs[key] = val
return Float8Tensor(**init_kwargs)
@staticmethod
def backward(ctx, grad):
return grad.to(ctx.input_dtype), None
class _ViewFunc(torch.autograd.Function):
"""View function
View the Float8Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: torch.Tensor,
shape: Optional[Tuple[int]] = None,
) -> torch.Tensor:
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
return tensor
# Construct new tensor if shape is provided
if isinstance(tensor, Float8Tensor):
return Float8Tensor.make_like(
tensor,
data=tensor._data.view(*shape),
)
return tensor.view(*shape)
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
grad,
data=grad._data.view(ctx.shape),
)
return dgrad, None
return grad.view(ctx.shape), None
class _ReshapeFunc(torch.autograd.Function):
"""Reshape function
Reshape the Float8Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: torch.Tensor,
shape: Optional[Tuple[int]] = None,
) -> torch.Tensor:
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
return tensor
# Construct new tensor if shape is provided
if isinstance(tensor, Float8Tensor):
return Float8Tensor.make_like(
tensor,
data=tensor._data.reshape(*shape),
)
return tensor.reshape(*shape)
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
grad,
data=grad._data.reshape(ctx.shape),
)
return dgrad, None
return grad.reshape(ctx.shape), None
class Float8Tensor(QuantizedTensor):
"""Experimental tensor class with FP8 data
The tensor presents as having a standard, higher-precision dtype,
but the data itself is (scaled) FP8. For most tensor operations,
the data will be cast to the nominal dtype before performing the
operation.
Parameters
----------
data: torch.Tensor
Raw FP8 data in a uint8 tensor
fp8_attrs: dict, optional
FP8 metadata, primarily managed by Float8Tensor. If
provided, all other FP8 configuration is ignored.
fp8_meta: dict, optional
FP8 metadata object, primarily managed by TE modules.
fp8_meta_forward: bool, default = `True`
Whether to access the FP8 metadata for the
forward pass. Ignored if fp8_meta is not
provided.
fp8_meta_index: int, optional
Index to access in FP8 meta tensors. Required if
fp8_meta is provided and otherwise ignored.
fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
FP8 format.
fp8_scale_inv: torch.Tensor
Reciprocal of the scaling factor applied when
casting to FP8, i.e. the scaling factor that must
be applied when casting from FP8 to higher
precision. Can be inferred from fp8_meta if
provided.
dtype: torch.dtype, default = torch.float32
Nominal tensor datatype.
"""
def __new__(
cls,
*,
data: torch.Tensor,
fp8_attrs: Optional[Dict[str, Any]] = None,
fp8_meta: Optional[Dict[str, Any]] = None,
fp8_meta_forward: bool = True,
fp8_meta_index: Optional[int] = None,
fp8_dtype: TE_DType = TE_DType.kFloat8E4M3,
fp8_scale_inv: Optional[torch.Tensor] = None,
dtype: torch.dtype = torch.float32,
data_transpose: Optional[torch.Tensor] = None,
):
# Check that data buffer is valid
if data.element_size() != 1:
raise ValueError(
f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})"
)
if data.requires_grad:
raise ValueError("Float8Tensor requires non-differentiable data buffer")
if not data.is_cuda:
data = data.cuda()
# Initialize tensor object
self = torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
strides=data.stride(),
storage_offset=data.storage_offset(),
dtype=dtype,
layout=data.layout,
requires_grad=data.requires_grad,
device=data.device,
)
self._data: torch.Tensor = data
# Initialize dict of class attributes
# Note: We store FP8 attributes in a dictionary so we can
# share them between tensors with the same data, e.g. detached
# tensors.
self._fp8_attrs: dict
if fp8_attrs is None:
self._fp8_attrs = {}
else:
self._fp8_attrs = fp8_attrs
return self
# FP8 meta tensors
if fp8_meta is not None and fp8_meta_index is None:
raise ValueError(
"To initialize Float8Tensor with FP8 meta tensors, "
"the FP8 meta tensor index must also be provided"
)
self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta
self._fp8_meta_forward: bool = fp8_meta_forward
self._fp8_meta_index: Optional[int] = fp8_meta_index
# FP8 dtype
assert fp8_dtype in (
TE_DType.kFloat8E4M3,
TE_DType.kFloat8E5M2,
), f"Unsupported fp8_dtype {fp8_dtype}."
self._fp8_dtype: TE_DType = fp8_dtype
# FP8 scale-inverse
if fp8_scale_inv is None and self._fp8_meta is not None:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
fp8_scale_inv = fp8_scale_inv.detach().view(1).clone()
if fp8_scale_inv is None:
raise ValueError(
"Attempted to initialize Float8Tensor without specifying scale-inverse"
)
if not isinstance(fp8_scale_inv, torch.Tensor):
fp8_scale_inv = torch.full(
[1],
fp8_scale_inv,
dtype=torch.float32,
device=self._data.device,
)
if fp8_scale_inv.numel() != 1:
raise ValueError(
"Attempted to initialize Float8Tensor with invalid scale-inverse tensor"
)
if fp8_scale_inv.dim() != 1:
fp8_scale_inv = fp8_scale_inv.reshape(1)
if (
not devices_match(fp8_scale_inv.device, self._data.device)
or fp8_scale_inv.dtype != torch.float32
):
fp8_scale_inv = fp8_scale_inv.to(
device=self._data.device,
dtype=torch.float32,
)
self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
# FP8 transpose cache
self._transpose: Optional[Float8Tensor] = data_transpose
self._transpose_invalid: bool = self._transpose is None
return self
@classmethod
def make_like(
cls,
tensor: Float8Tensor,
*,
data: torch.Tensor,
fp8_attrs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Float8Tensor:
"""Use attributes of a Float8Tensor to create another Float8Tensor
See constructor for list of keyword arguments.
"""
default_kwargs = dict(
fp8_meta=tensor._fp8_meta,
fp8_meta_forward=tensor._fp8_meta_forward,
fp8_meta_index=tensor._fp8_meta_index,
fp8_dtype=tensor._fp8_dtype,
fp8_scale_inv=tensor._scale_inv,
dtype=tensor.dtype,
)
for key, val in default_kwargs.items():
if key not in kwargs:
kwargs[key] = val
return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs)
def __repr__(self):
return (
"Float8Tensor("
f"fp8_dtype={self._fp8_dtype}, "
f"scale_inv={self._scale_inv.item()}, "
f"data={self.from_float8(dtype=self.dtype)}"
")"
)
def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# Convert PyTorch dtype to TE dtype
if dtype is None:
dtype = self.dtype
dtype = torch_to_transformer_engine_dtype[dtype]
# Make sure FP8 data is in expected format
data = self._data
if data.device.type != "cuda":
data = data.cuda()
if not data.is_contiguous():
data = data.contiguous()
if data.dim() != 2:
data = data.view(1, -1)
# Cast from FP8
out = cast_from_fp8(
data.view(1, -1),
None, # fp8_meta_tensor
None, # fp8_tensor
self._fp8_dtype,
dtype,
scale_inv=self._scale_inv,
)
# Make sure output is in expected format
if out.size() != self.size():
out = out.view(self.size())
return out
def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Construct plain PyTorch tensor from Float8Tensor
By default the resulting tensor's dtype is the
Float8Tensor's nominal dtype.
"""
return _FromFloat8Func.apply(self, dtype)
def quantize_(
self,
tensor: torch.Tensor,
*,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
noop_flag: Optional[torch.Tensor] = None,
) -> Float8Tensor:
"""Update FP8 data
Parameters
----------
tensor: torch.Tensor
Tensor to copy from
scale: torch.Tensor, optional
Scaling factor to use for FP8 quantization
amax: torch.Tensor, optional
History of maximum absolute values. The first entry will
be updated with the absmax of `tensor`.
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
src = tensor
dst = self
# In-place operations invalidate transpose cache
self._reset_caches()
# Special logic if other tensor is Float8Tensor
if isinstance(src, Float8Tensor):
# Cast to plain tensor if FP8 dtypes don't match
if dst._fp8_dtype != src._fp8_dtype:
return dst.quantize_(src.dequantize())
# Directly copy FP8 data
dst._data.copy_(src._data.detach())
dst._scale_inv.copy_(src._scale_inv.detach())
if amax is not None or dst._fp8_meta is not None:
src_amax: torch.Tensor
if src._fp8_meta is None:
src_min, src_max = src.dequantize().aminmax()
src_amax = torch.maximum(-src_min, src_max)
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=src._fp8_meta_forward,
)
fp8_meta_index = src._fp8_meta_index
src_amax = src._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index]
dst_amax: torch.Tensor
if amax is None:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dst._fp8_meta_forward,
)
fp8_meta_index = dst._fp8_meta_index
dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index]
else:
dst_amax = amax
if dst_amax.dim() > 0:
dst_amax = dst_amax[tuple([0] * dst_amax.dim())]
torch.maximum(src_amax, dst_amax, out=dst_amax)
if dst._transpose is not None:
if src._transpose is None:
dst.transpose_2d(force_compute=True, fill_cache=True)
else:
dst._transpose.copy_(src._transpose)
dst._transpose_invalid = False
return self
# Convert QuantizedTensor to plain tensor
if isinstance(src, QuantizedTensor):
return dst.quantize_(src.dequantize())
# Make sure input is in expected format
if src.size() != dst.size():
src = src.expand(dst.size())
if not devices_match(src.device, dst.device):
src = src.to(device=dst.device)
if src.dtype not in (torch.float32, torch.bfloat16, torch.float16):
src = src.float()
if not src.is_contiguous():
src = src.contiguous()
# Make sure FP8 scaling factors are in expected format
if scale is not None:
if isinstance(scale, torch.Tensor):
if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32:
scale = scale.to(device=dst.device, dtype=torch.float32)
else:
scale = torch.full([1], scale, dtype=torch.float32, device=dst.device)
if amax is not None:
while amax.dim() < 2:
amax = amax.unsqueeze(0)
if not devices_match(amax.device, dst.device):
raise ValueError(
f"Invalid device for amax (expected {dst.device}, found {amax.device})"
)
if amax.dtype != torch.float32:
raise ValueError(f"Invalid dtype for amax (expected float32, found {amax.type})")
# Default FP8 scaling factors
fp8_meta = None
if dst._fp8_meta is None:
if scale is None:
scale = dst._scale_inv.reciprocal()
if amax is None:
amax = torch.empty((1, 1), dtype=torch.float32, device=dst.device)
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dst._fp8_meta_forward,
)
fp8_meta = dst._fp8_meta[fp8_meta_key]
# Check local data
if not dst._data.is_contiguous():
raise RuntimeError("Transformer Engine cast kernels require contiguous data")
# Perform FP8 cast
if dst._transpose is None:
dst_data = dst._data
if src.dim() != 2:
src = src.view(1, -1)
dst_data = dst_data.view(1, -1)
cast_to_fp8(
src,
fp8_meta,
dst._fp8_meta_index,
dst._fp8_dtype,
out=dst_data,
scale=scale,
amax=amax,
scale_inv=dst._scale_inv,
)
else:
fp8_cast_transpose_fused(
src.view(-1, src.size(-1)),
fp8_meta,
dst._fp8_meta_index,
dst._fp8_dtype,
cast_out=dst._data,
transpose_out=dst._transpose,
scale=scale,
amax=amax,
scale_inv=dst._scale_inv,
noop_flag=noop_flag,
)
dst._transpose_invalid = False
# Callback hook to perform amax reduction after optimizer step
post_optimizer_step_fwd_amax_reduction(self)
return self
@classmethod
def to_float8(
cls,
tensor: torch.Tensor,
*,
fp8_meta: Optional[Dict[str, Any]] = None,
fp8_meta_forward: bool = True,
fp8_meta_index: Optional[int] = None,
fp8_dtype: TE_DType = TE_DType.kFloat8E4M3,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
with_transpose_cache: bool = False,
):
"""Construct Float8Tensor from plain PyTorch tensor"""
return _ToFloat8Func.apply(
tensor,
fp8_meta,
fp8_meta_forward,
fp8_meta_index,
fp8_dtype,
scale,
amax,
scale_inv,
with_transpose_cache,
)
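# Illustrative usage sketch (comment only; assumes a CUDA device is available):
#
#     x = torch.randn(128, 128, device="cuda")
#     x_fp8 = Float8Tensor.to_float8(x, fp8_dtype=TE_DType.kFloat8E5M2)
#     y = x_fp8.from_float8()           # back to x.dtype, with quantization error
#     max_err = (x - y).abs().max()
#
# Both casts are autograd functions, so gradients flow through x_fp8 in full
# precision.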
def detach(self) -> Float8Tensor:
return Float8Tensor.make_like(
self,
data=self._data,
fp8_attrs=self._fp8_attrs,
)
def clone(self) -> Float8Tensor:
data = self._data.detach().clone()
data_transpose = None
if self._transpose is not None:
data_transpose = self._transpose.detach().clone()
return _IdentityFunc.apply(
self,
dict(
data=data,
data_transpose=data_transpose,
),
)
def view(self, *shape: Tuple[int]) -> Float8Tensor:
return _ViewFunc.apply(self, shape)
def reshape(self, *shape: Tuple[int]) -> Float8Tensor:
return _ReshapeFunc.apply(self, shape)
def contiguous(
self,
*,
memory_format: torch.memory_format = torch.contiguous_format,
) -> Float8Tensor:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
if self._data.is_contiguous(memory_format=memory_format):
return self
return _IdentityFunc.apply(
self,
{"data": self._data.detach().contiguous(memory_format=memory_format)},
)
def transpose_2d(
self,
*,
force_compute: bool = False,
fill_cache: bool = False,
noop_flag: Optional[torch.Tensor] = None,
cache: Optional[bool] = None,
) -> torch.Tensor:
"""
2D transpose with caching support.
Parameters
----------
force_compute: bool, default = `False`
Force computation of transpose. Otherwise use
cached values, if possible.
fill_cache: bool, default = `False`
Cache output tensor for future function calls.
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid updating
cached values, if possible.
cache: bool, deprecated
"""
# Handle deprecated cache kwarg
if cache is not None:
msg = (
"cache kwarg for Float8Tensor.transpose_2d is deprecated, "
"please use force_compute and fill_cache instead"
)
warnings.warn(msg, DeprecationWarning)
if cache:
force_compute = False
fill_cache = True
else:
force_compute = True
fill_cache = False
# Need to compute transpose if cache is invalid
need_compute = force_compute
if self._transpose is None:
need_compute = True
elif self._transpose_invalid:
need_compute = True
# Need to apply transpose kernel if noop flag is applied
if noop_flag is not None:
need_compute = True
# Return cached transpose if possible
if not need_compute:
return self._transpose
# Allocate output if needed
data = self._data.contiguous().reshape(-1, self.size(-1))
out = self._transpose
if out is None:
out = torch.empty(
(data.size(1), data.size(0)),
dtype=torch.uint8,
device=data.device,
)
noop_flag = None
else:
self._transpose_invalid = False
# Apply transpose kernel
fp8_dtype = self._fp8_dtype
if noop_flag is None:
tex.fp8_transpose_noalloc(data, out, fp8_dtype)
else:
noop_flag = noop_flag.to(dtype=torch.float32, device=data.device)
tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype)
# Fill cache if needed
if fill_cache:
self._transpose = out
self._transpose_invalid = False
return out
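# Illustrative caching behaviour sketch (comment only), given a Float8Tensor
# x_fp8 such as the one built in the to_float8 sketch above:
#
#     t1 = x_fp8.transpose_2d(fill_cache=True)   # computes and caches transpose
#     t2 = x_fp8.transpose_2d()                   # returns the cached tensor
#     x_fp8 += 1                                  # in-place op invalidates cache
#     t3 = x_fp8.transpose_2d()                   # recomputes into the cache buffer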
@torch.no_grad()
def cast_transpose_(
self,
tensor: torch.Tensor,
noop_flag: Optional[torch.Tensor] = None,
) -> None:
"""Cast from tensor and populate transpose cache
Tensor is reshaped as a 2D matrix.
Parameters
----------
tensor: torch.Tensor
Tensor to copy from. Must have same dimensions as
destination tensor.
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid updating
destination tensor.
"""
if self._transpose is None:
self._transpose = torch.empty(
(self.size(-1), self.numel() // self.size(-1)),
dtype=torch.uint8,
device=self.device,
)
self.quantize_(tensor, noop_flag=noop_flag)
@torch.no_grad()
def reset_fp8_meta_scale_inv(self) -> None:
"""Replace FP8 meta tensor scale-inverse with cached value
The FP8 meta tensor scale_inv entry corresponding to this
tensor is replaced with the scale_inv value used to construct
the tensor.
"""
assert self._fp8_meta is not None, "FP8 meta tensors not found."
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0])
def to_dtype(self, dtype: torch.dtype) -> Float8Tensor:
"""Create `Float8Tensor` with given nominal dtype
The new tensor has the same underlying FP8 data.
"""
return Float8Tensor.make_like(
self,
data=self._data,
fp8_attrs=self._fp8_attrs,
dtype=dtype,
)
def _reset_caches(self) -> None:
"""
Set transpose cache as invalid.
Should be called after any in-place operation.
"""
self._transpose_invalid = True
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# Slice op
if func == aten.slice.Tensor:
tensor = args[0]
data = tensor._data
data_slice = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=data_slice)
# View op
if func == aten.view.default:
tensor = args[0]
data = tensor._data
data_view = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=data_view)
# Default case
return super().__torch_dispatch__(func, types, args, kwargs)
@classmethod
def _make_in_reduce_ex(
cls,
data: torch.Tensor,
fp8_dtype: TE_DType,
fp8_scale_inv: torch.Tensor,
dtype: torch.dtype,
) -> Float8Tensor:
"""Build Float8Tensor, for use in __reduce__
__reduce_ex__ assumes object constructor has positional
arguments.
"""
return Float8Tensor(
data=data,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
dtype=dtype,
)
def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling to remove references to FP8 metadata objects"""
return (
Float8Tensor._make_in_reduce_ex,
(self._data, self._fp8_dtype, self._scale_inv, self.dtype),
)
def _get_data(self) -> Float8Tensor:
"""Get tensor data property"""
return super().data
def _set_data(self, tensor: torch.Tensor) -> None:
"""Set tensor data property
Cast tensor to FP8 and store in FP8 buffer.
"""
with torch.no_grad():
self.copy_(tensor)
# Cast to FP8 when setting Float8Tensor.data
data = property(_get_data, _set_data)
# Accessors for objects in self._fp8_attrs
# Note: We store FP8 attributes in a dictionary so we can share
# them between tensors with the same data, e.g. detached tensors.
# For convenience, we also expose them as property attributes.
_fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta"))
_fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward"))
_fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index"))
_fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype"))
_transpose = property(**_make_fp8_attr_property_funcs("transpose"))
_transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid"))
_scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv"))
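# End-to-end sketch (comment only; assumes a CUDA device is available):
# detached tensors share FP8 metadata through _fp8_attrs, and assigning to
# .data re-quantizes in place via quantize_.
#
#     w_fp8 = Float8Tensor.to_float8(torch.randn(64, 64, device="cuda"))
#     w_det = w_fp8.detach()                           # shares _fp8_attrs
#     w_fp8.data = torch.zeros(64, 64, device="cuda")  # quantize_ under the hood
#     assert w_det._scale_inv is w_fp8._scale_inv      # same shared metadata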
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor with quantized data"""
from __future__ import annotations
from typing import Optional, Tuple
import torch
from torch.utils._pytree import tree_map
class _DequantizeFunc(torch.autograd.Function):
"""Autograd function to convert quantized tensor to standard tensor"""
@staticmethod
def forward(
_ctx: torch.autograd.function.FunctionCtx, # unused
tensor: QuantizedTensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
return tensor.dequantize(dtype=dtype)
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
return grad, None
class _IdentityFunc(torch.autograd.Function):
"""Autograd function to create quantized tensor with same data"""
@staticmethod
def forward(
_ctx: torch.autograd.function.FunctionCtx, # unused
tensor: QuantizedTensor,
) -> QuantizedTensor:
return tensor.detach()
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> torch.Tensor:
return grad
class QuantizedTensor(torch.Tensor):
"""Abstract base class for tensor with quantized data
This is a proxy class with the interface of a standard PyTorch
tensor, but with data that has been encoded with some quantization
scheme. Derived classes should implement the quantization scheme
by overriding the `quantize_` and `dequantize` functions.
"""
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""Convert quantized data to standard PyTorch tensor"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement dequantize function"
)
def quantize_(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Update quantized data in-place"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement quantize_ function"
)
def detach(self) -> QuantizedTensor:
"""Create new quantized tensor with same data
Output tensor must be detached from the current autograd
graph.
"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement detach function"
)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"
def float(self) -> torch.Tensor:
return _DequantizeFunc.apply(self, torch.float32)
def bfloat16(self) -> torch.Tensor:
return _DequantizeFunc.apply(self, torch.bfloat16)
def half(self) -> torch.Tensor:
return _DequantizeFunc.apply(self, torch.float16)
def cpu(self) -> torch.Tensor:
return _DequantizeFunc.apply(self).cpu()
def expand_as(self, other: torch.Tensor) -> torch.Tensor:
if other is self:
# Note: expand_as is hackily used to create dummy autograd nodes
# and access the backward graph (see
# https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026).
# We add a dummy identity function to handle this case.
return _IdentityFunc.apply(self)
return super().expand_as(other)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# Detach op
if func == torch.ops.aten.detach.default:
return args[0].detach()
# In-place copy op
if func == torch.ops.aten.copy_.default:
dst = args[0]
src = args[1]
if isinstance(dst, QuantizedTensor):
dst.quantize_(src)
else:
if isinstance(src, QuantizedTensor):
src = src.dequantize()
dst.copy_(src)
return None
# View op
if func == torch.ops.aten.view.default:
raise NotImplementedError("{cls.__name__} class does not support tensor views")
def maybe_unwrap(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize(dtype=arg.dtype)
return arg
def maybe_update_inplace(arg, new_arg, schema_arg):
if (
isinstance(arg, QuantizedTensor)
and isinstance(new_arg, torch.Tensor)
and hasattr(schema_arg, "alias_info")
and hasattr(schema_arg.alias_info, "is_write")
and schema_arg.alias_info.is_write
):
arg.quantize_(new_arg)
# In-place op: dequantize, perform op, and quantize
if func._schema.is_mutable:
new_args = tree_map(maybe_unwrap, args)
new_kwargs = tree_map(maybe_unwrap, kwargs)
schema_args = func._schema.arguments
args_len = len(args)
super().__torch_dispatch__(func, types, new_args, new_kwargs)
for arg, new_arg, schema_arg in zip(args, new_args, schema_args):
maybe_update_inplace(arg, new_arg, schema_arg)
for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match"
maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg)
return None
# Default op: dequantize and perform op
args = tree_map(maybe_unwrap, args)
if kwargs is not None:
kwargs = tree_map(maybe_unwrap, kwargs)
out = super().__torch_dispatch__(func, types, args, kwargs)
return out
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
# Do not force the QuantizedTensor type on the returned tensor
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
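# Dispatch behaviour sketch (comment only), for a hypothetical subclass
# MyQuantTensor that overrides quantize_, dequantize, and detach:
#
#     q = MyQuantTensor(...)      # quantized storage
#     y = q * 2                   # dequantize -> multiply -> plain torch.Tensor
#     q.mul_(2)                   # dequantize -> multiply -> quantize_ back in place
#     q_det = q.detach()          # routed to the subclass detach()
#
# Out-of-place ops therefore return plain tensors rather than new quantized
# tensors; only in-place ops write back through quantize_.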
......@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Utility functions for Transformer Engine modules"""
from __future__ import annotations
import functools
import math
from typing import Any, Callable, Optional, Tuple
......@@ -251,3 +252,52 @@ def get_cudnn_version() -> Tuple[int, int, int]:
major, encoded_version = divmod(encoded_version, major_version_magnitude)
minor, patch = divmod(encoded_version, 100)
return (major, minor, patch)
def canonicalize_device(device: Optional[torch.device | str]) -> torch.device:
"""Canonicalize PyTorch device
If `None`, then returns the default CUDA device.
"""
if device is None:
# Use default CUDA device
device = torch.get_default_device()
if device.type != "cuda":
device = torch.device("cuda", torch.cuda.current_device())
elif not isinstance(device, torch.device):
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device("cuda", torch.cuda.current_device())
return device
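# Behaviour sketch (comment only; assumes CUDA device 0 is current):
#
#     canonicalize_device(None)      # -> torch.device("cuda", 0)
#     canonicalize_device("cuda")    # -> torch.device("cuda", 0), index filled in
#     canonicalize_device("cpu")     # -> torch.device("cpu"), non-CUDA passes through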
def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
"""Canonicalize PyTorch datatype
If `None`, then returns the default PyTorch datatype.
"""
if dtype is None:
# Use default dtype
dtype = torch.get_default_dtype()
return dtype
def devices_match(device1: torch.device, device2: torch.device) -> bool:
"""Whether two devices are the same"""
device1 = torch.device(device1)
device2 = torch.device(device2)
if device1.type != device2.type:
return False
if device1.type == "cuda":
index1 = device1.index
index2 = device2.index
if index1 == index2:
return True
if index1 is None:
index1 = torch.cuda.current_device()
if index2 is None:
index2 = torch.cuda.current_device()
return index1 == index2
return device1 == device2
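# Behaviour sketch (comment only): unindexed CUDA devices are resolved against
# the current device before comparison.
#
#     devices_match("cuda", "cuda:0")    # True when device 0 is current
#     devices_match("cuda:0", "cuda:1")  # False
#     devices_match("cpu", "cuda:0")     # False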