Unverified Commit 1e780946 authored by Tim Moon, committed by GitHub

[PyTorch] Add Float8Tensor option to avoid updating transpose cache when possible (#662)



* Add option to avoid updating transpose cache when possible
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix typo
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use string kwarg for FP8 transpose caching
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused attr
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent bdf1afee
@@ -298,22 +298,44 @@ class TestFloat8Tensor:
         # Check transpose caching
         if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]:

             # Check that cached transpose is returned when expected
             # Note: Sneakily destroy data so that recalculating
             # transpose would give wrong answer.
             x_fp8 += 0.5
             x_ref = x_fp8.from_float8()
             torch.testing.assert_close(
-                x_fp8.transpose(*transpose_dims, update_cache=True),
+                x_fp8.transpose(*transpose_dims, update_cache="lazy"),
                 x_ref.transpose(*transpose_dims),
                 **tols,
             )
             x_fp8_data = x_fp8._data.clone()
             x_fp8._data.zero_()
             torch.testing.assert_close(
                 x_fp8.transpose(*transpose_dims),
                 x_ref.transpose(*transpose_dims),
                 **tols,
             )
             torch.testing.assert_close(
-                x_fp8.transpose(*transpose_dims, update_cache=True),
+                x_fp8.transpose(*transpose_dims, update_cache="lazy"),
                 x_ref.transpose(*transpose_dims),
                 **tols,
             )
+            torch.testing.assert_close(
+                x_fp8.transpose(*transpose_dims, update_cache="force"),
+                torch.zeros_like(x_ref.transpose(*transpose_dims)),
+                rtol=0,
+                atol=0,
+            )
             x_fp8._data.copy_(x_fp8_data)
             x_fp8._reset_caches()

             # Make sure cache is reset after in-place operation
+            x_fp8.transpose(*transpose_dims, update_cache="force")
             x_fp8 += 0.5
             x_ref = x_fp8.from_float8()
             torch.testing.assert_close(
-                x_fp8.transpose(*transpose_dims, update_cache=True),
+                x_fp8.transpose(*transpose_dims),
                 x_ref.transpose(*transpose_dims),
                 **tols,
             )
@@ -440,7 +440,7 @@ class Float8Tensor(torch.Tensor):
         dim0: int = 0,
         dim1: int = 1,
         *,
-        update_cache: bool = False,
+        update_cache: str | bool = "reuse_only",
     ) -> torch.Tensor:
         """
         Swap tensor dimensions
@@ -454,23 +454,36 @@ class Float8Tensor(torch.Tensor):
             The first dimension to be transposed
         dim1: int, default = 1
             The second dimension to be transposed
-        update_cache: bool, default = False
-            If `True`, the transpose is computed and stored
-            in a cache. If `False`, a cached version is
-            returned if available and otherwise the
-            transpose is computed. Caching is only supported
+        update_cache: str or bool, default = "reuse_only"
+            Memoization behavior. Options are
+            "reuse_only"/`False` (reuse cached value if
+            available, otherwise calculate transpose without
+            caching), "force"/`True` (calculate transpose
+            and cache), "lazy" (reuse cached value if
+            available, otherwise calculate transpose and
+            cache if possible). Caching is only supported
             for basic 2D transposes and the cache is reset
             after any in-place operations.
         """

+        # Check caching mode
+        if not isinstance(update_cache, str):
+            update_cache = "force" if update_cache else "reuse_only"
+        if update_cache not in ("force", "reuse_only", "lazy"):
+            raise ValueError(
+                "Supported values for update_cache are "
+                '"force" (True), "reuse_only" (False), "lazy" '
+                f"(got {update_cache})"
+            )
+
         # Handle non-2D transposes
         if -self.dim() <= dim0 < 0:
             dim0 += self.dim()
         if -self.dim() <= dim1 < 0:
             dim1 += self.dim()
         if self.dim() != 2 or dim0 == dim1:
-            if update_cache:
+            if update_cache == "force":
                 raise ValueError(
                     "Transpose caching is only supported for basic 2D transposes "
                     f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
@@ -478,7 +491,7 @@ class Float8Tensor(torch.Tensor):
             return super().transpose(dim0, dim1)

         # Clear cache if needed
-        if update_cache:
+        if update_cache == "force":
             self._transpose = None

         # Compute transpose if needed
@@ -493,7 +506,7 @@ class Float8Tensor(torch.Tensor):
         )

         # Update cache if needed
-        if update_cache:
+        if update_cache in ("force", "lazy"):
             self._transpose = out
         return out
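For context, a minimal usage sketch of the new kwarg (not part of the commit). It assumes a CUDA device and the `Float8Tensor.to_float8` constructor used elsewhere in the test suite:

import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Hypothetical setup: cast a random matrix to FP8 (CUDA assumed).
x = Float8Tensor.to_float8(torch.randn(32, 64, device="cuda"))

# "reuse_only" (the default, equivalent to False): return the cached
# transpose if one exists, otherwise compute without touching the cache.
x_t = x.transpose(0, 1)

# "lazy": return the cached transpose if one exists, otherwise
# compute it and populate the cache.
x_t = x.transpose(0, 1, update_cache="lazy")

# "force" (equivalent to True): always recompute and overwrite the cache.
x_t = x.transpose(0, 1, update_cache="force")

# In-place operations reset the cache, so the next "lazy" call recomputes.
x += 0.5
x_t = x.transpose(0, 1, update_cache="lazy")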
@@ -331,7 +331,9 @@ class _LayerNormLinear(torch.autograd.Function):

         # Primary weights are in FP8.
         if ctx.fp8 and weight_t_fp8 is None:
-            weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
+            weight_t_fp8 = weight.transpose(
+                update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
+            )

         if ctx.ub_bulk_dgrad:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
@@ -560,10 +560,11 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc2_weight.main_grad = fc2_weight_main_grad

             # Primary weights are in FP8.
+            update_transpose_cache = "reuse_only" if ctx.is_first_microbatch is None else "lazy"
             if ctx.fp8 and fc1_weight_t_fp8 is None:
-                fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=ctx.is_first_microbatch)
+                fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=update_transpose_cache)
             if ctx.fp8 and fc2_weight_t_fp8 is None:
-                fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=ctx.is_first_microbatch)
+                fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=update_transpose_cache)

             activation_func = _act_func(ctx.activation)[1]
@@ -347,7 +347,9 @@ class _Linear(torch.autograd.Function):

         # Primary weights are in FP8.
         if ctx.fp8 and weight_t_fp8 is None:
-            weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
+            weight_t_fp8 = weight.transpose(
+                update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
+            )

         if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
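All three modules map the `is_first_microbatch` hint to a caching mode the same way. A sketch of that mapping as a standalone helper (hypothetical; the commit inlines the expression rather than defining a function):

from typing import Optional

def transpose_cache_mode(is_first_microbatch: Optional[bool]) -> str:
    """Choose an update_cache mode from the is_first_microbatch hint.

    Hypothetical helper for illustration; the modules above inline
    this expression.
    """
    if is_first_microbatch is None:
        # No microbatching hint: the weights may change every step, so
        # skip populating the cache and avoid the wasted memory/compute.
        return "reuse_only"
    # Microbatching in use: weights are constant across microbatches, so
    # cache the transpose on first use and reuse it afterwards.
    # (In-place weight updates reset the cache, which keeps this safe.)
    return "lazy"

# e.g. weight_t_fp8 = weight.transpose(
#     update_cache=transpose_cache_mode(ctx.is_first_microbatch))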