Unverified Commit 15088217 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[PyTorch] FP8 Tensor improvements (#500)



* Delay caching of transposes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comment
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e6676c53
......@@ -168,9 +168,6 @@ class _LayerNormLinear(torch.autograd.Function):
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight
weight_t_fp8 = None
if is_grad_enabled:
weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
......@@ -306,6 +303,10 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_scale_inverses,
) = ctx.saved_tensors
# Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
if ctx.ub_bulk_dgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
......
......@@ -207,10 +207,6 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_fp8 = fc2_weight
fc1_weight_t_fp8 = None
fc2_weight_t_fp8 = None
if is_grad_enabled:
fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch)
fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch)
elif update_fp8_weights:
# Need to cast weights to FP8
fc1_weight_fp8 = Float8Tensor(
......@@ -499,6 +495,12 @@ class _LayerNormMLP(torch.autograd.Function):
fwd_scale_inverses,
) = ctx.saved_tensors
# Primary weights are in FP8.
if ctx.fp8 and fc1_weight_t_fp8 is None:
fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=ctx.is_first_microbatch)
if ctx.fp8 and fc2_weight_t_fp8 is None:
fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=ctx.is_first_microbatch)
activation_func = _act_func(ctx.activation)[1]
if ctx.ub_bulk_dgrad:
......
......@@ -151,9 +151,6 @@ class _Linear(torch.autograd.Function):
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight
weight_t_fp8 = None
if is_grad_enabled:
weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
......@@ -264,9 +261,6 @@ class _Linear(torch.autograd.Function):
if is_grad_enabled:
fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
if fp8:
assert hasattr(weight_t_fp8, "_data"), \
"_data attr doesn't exist (before save for bwd)"
ctx.save_for_backward(
inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
inputmat_t if weight.requires_grad and fp8_wgrad else None,
......@@ -317,9 +311,10 @@ class _Linear(torch.autograd.Function):
weight_t_fp8,
fwd_scale_inverses,
) = ctx.saved_tensors
if weight_t_fp8 is not None:
assert hasattr(weight_t_fp8, "_data"), \
"_data attr doesn't exist (after restore in bwd)"
# Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment