Unverified commit 1e780946, authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Add Float8Tensor option to avoid updating transpose cache when possible (#662)



* Add option to avoid updating transpose cache when possible
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix typo
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use string kwarg for FP8 transpose caching
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused attr
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent bdf1afee
...@@ -298,22 +298,44 @@ class TestFloat8Tensor: ...@@ -298,22 +298,44 @@ class TestFloat8Tensor:
# Check transpose caching # Check transpose caching
if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]: if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]:
# Check that cached transpose is returned when expected
# Note: Sneakily destroy data so that recalculating
# transpose would give wrong answer.
x_fp8 += 0.5 x_fp8 += 0.5
x_ref = x_fp8.from_float8() x_ref = x_fp8.from_float8()
torch.testing.assert_close( torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True), x_fp8.transpose(*transpose_dims, update_cache="lazy"),
x_ref.transpose(*transpose_dims), x_ref.transpose(*transpose_dims),
**tols, **tols,
) )
x_fp8_data = x_fp8._data.clone()
x_fp8._data.zero_()
torch.testing.assert_close( torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True), x_fp8.transpose(*transpose_dims),
x_ref.transpose(*transpose_dims), x_ref.transpose(*transpose_dims),
**tols, **tols,
) )
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache="lazy"),
x_ref.transpose(*transpose_dims),
**tols,
)
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache="force"),
torch.zeros_like(x_ref.transpose(*transpose_dims)),
rtol=0,
atol=0,
)
x_fp8._data.copy_(x_fp8_data)
x_fp8._reset_caches()
# Make sure cache is reset after in-place operation
x_fp8.transpose(*transpose_dims, update_cache="force")
x_fp8 += 0.5 x_fp8 += 0.5
x_ref = x_fp8.from_float8() x_ref = x_fp8.from_float8()
torch.testing.assert_close( torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True), x_fp8.transpose(*transpose_dims),
x_ref.transpose(*transpose_dims), x_ref.transpose(*transpose_dims),
**tols, **tols,
) )
......
...@@ -440,7 +440,7 @@ class Float8Tensor(torch.Tensor): ...@@ -440,7 +440,7 @@ class Float8Tensor(torch.Tensor):
dim0: int = 0, dim0: int = 0,
dim1: int = 1, dim1: int = 1,
*, *,
update_cache: bool = False, update_cache: str | bool = "reuse_only",
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Swap tensor dimensions Swap tensor dimensions
...@@ -454,23 +454,36 @@ class Float8Tensor(torch.Tensor): ...@@ -454,23 +454,36 @@ class Float8Tensor(torch.Tensor):
The first dimension to be transposed The first dimension to be transposed
dim1: int, default = 1 dim1: int, default = 1
The second dimension to be transposed The second dimension to be transposed
update_cache: bool, default = False update_cache: str or bool, default = "reuse_only"
If `True`, the transpose is computed and stored Memoization behavior. Options are
in a cache. If `False`, a cached version is "reuse_only"/`False` (reuse cached value if
returned if available and otherwise the available, otherwise calculate transpose without
transpose is computed. Caching is only supported caching), "force"/`True` (calculate transpose
and cache), "lazy" (reuse cached value if
available, otherwise calculate transpose and
cache if possible). Caching is only supported
for basic 2D transposes and the cache is reset for basic 2D transposes and the cache is reset
after any in-place operations. after any in-place operations.
""" """
# Check caching mode
if not isinstance(update_cache, str):
update_cache = "force" if update_cache else "reuse_only"
if update_cache not in ("force", "reuse_only", "lazy"):
raise ValueError(
"Supported values for update_cache are "
'"force" (True), "reuse_only" (False), "lazy" '
f"(got {update_cache})"
)
# Handle non-2D transposes # Handle non-2D transposes
if -self.dim() <= dim0 < 0: if -self.dim() <= dim0 < 0:
dim0 += self.dim() dim0 += self.dim()
if -self.dim() <= dim1 < 0: if -self.dim() <= dim1 < 0:
dim1 += self.dim() dim1 += self.dim()
if self.dim() != 2 or dim0 == dim1: if self.dim() != 2 or dim0 == dim1:
if update_cache: if update_cache == "force":
raise ValueError( raise ValueError(
"Transpose caching is only supported for basic 2D transposes " "Transpose caching is only supported for basic 2D transposes "
f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
...@@ -478,7 +491,7 @@ class Float8Tensor(torch.Tensor): ...@@ -478,7 +491,7 @@ class Float8Tensor(torch.Tensor):
return super().transpose(dim0, dim1) return super().transpose(dim0, dim1)
# Clear cache if needed # Clear cache if needed
if update_cache: if update_cache == "force":
self._transpose = None self._transpose = None
# Compute transpose if needed # Compute transpose if needed
...@@ -493,7 +506,7 @@ class Float8Tensor(torch.Tensor): ...@@ -493,7 +506,7 @@ class Float8Tensor(torch.Tensor):
) )
# Update cache if needed # Update cache if needed
if update_cache: if update_cache in ("force", "lazy"):
self._transpose = out self._transpose = out
return out return out
......
...@@ -331,7 +331,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -331,7 +331,9 @@ class _LayerNormLinear(torch.autograd.Function):
# Primary weights are in FP8. # Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None: if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch) weight_t_fp8 = weight.transpose(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
)
if ctx.ub_bulk_dgrad: if ctx.ub_bulk_dgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group) tp_world_size = get_distributed_world_size(ctx.tp_group)
......
...@@ -560,10 +560,11 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -560,10 +560,11 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight.main_grad = fc2_weight_main_grad fc2_weight.main_grad = fc2_weight_main_grad
# Primary weights are in FP8. # Primary weights are in FP8.
update_transpose_cache = "reuse_only" if ctx.is_first_microbatch is None else "lazy"
if ctx.fp8 and fc1_weight_t_fp8 is None: if ctx.fp8 and fc1_weight_t_fp8 is None:
fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=ctx.is_first_microbatch) fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=update_transpose_cache)
if ctx.fp8 and fc2_weight_t_fp8 is None: if ctx.fp8 and fc2_weight_t_fp8 is None:
fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=ctx.is_first_microbatch) fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=update_transpose_cache)
activation_func = _act_func(ctx.activation)[1] activation_func = _act_func(ctx.activation)[1]
......
...@@ -347,7 +347,9 @@ class _Linear(torch.autograd.Function): ...@@ -347,7 +347,9 @@ class _Linear(torch.autograd.Function):
# Primary weights are in FP8. # Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None: if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch) weight_t_fp8 = weight.transpose(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
)
if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group) tp_world_size = get_distributed_world_size(ctx.tp_group)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment