Unverified Commit 15088217 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] FP8 Tensor improvements (#500)



* Delay caching of transposes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comment
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e6676c53
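
In short, when the primary weights are already kept in FP8 this change stops caching the weight transpose eagerly in forward() and instead builds (or refreshes) it lazily in backward(), only when gradients are actually computed. Below is a minimal sketch of that pattern in plain PyTorch; the class name is hypothetical and an ordinary transpose stands in for Float8Tensor.transpose(update_cache=...), so this is an illustration of the idea rather than the actual module code.

import torch

class _DelayedTransposeLinear(torch.autograd.Function):
    """Hypothetical sketch of the delayed-transpose pattern from this PR."""

    @staticmethod
    def forward(ctx, inp, weight, is_first_microbatch):
        # Before this change the weight transpose was computed (and cached)
        # here whenever grad was enabled; now forward only saves the weight.
        ctx.save_for_backward(inp, weight)
        ctx.is_first_microbatch = is_first_microbatch
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        # Delayed transpose: with a Float8Tensor weight this would be
        #   weight_t = weight.transpose(update_cache=ctx.is_first_microbatch)
        # so the cached FP8 transpose is only materialized once backward runs.
        weight_t = weight.t()
        # The dgrad GEMM consumes the transposed weight (equivalent here to
        # grad_output @ weight).
        grad_input = grad_output @ weight_t.t()
        grad_weight = grad_output.t() @ inp
        return grad_input, grad_weight, None

With microbatching, is_first_microbatch=True would refresh the cached transpose and later microbatches would reuse it; the hunks below apply this to _LayerNormLinear, _LayerNormMLP and _Linear.
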
@@ -168,9 +168,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 weight.reset_fp8_meta_scale_inv()
                 weight_fp8 = weight
                 weight_t_fp8 = None
-                if is_grad_enabled:
-                    weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)
             elif update_fp8_weights:
                 # Need to cast weights to FP8
                 weight_fp8 = Float8Tensor(
@@ -306,6 +303,10 @@ class _LayerNormLinear(torch.autograd.Function):
             fwd_scale_inverses,
         ) = ctx.saved_tensors
+        # Primary weights are in FP8.
+        if ctx.fp8 and weight_t_fp8 is None:
+            weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
         if ctx.ub_bulk_dgrad:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
             if tp_world_size == 1:
@@ -207,10 +207,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc2_weight_fp8 = fc2_weight
                 fc1_weight_t_fp8 = None
                 fc2_weight_t_fp8 = None
-                if is_grad_enabled:
-                    fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch)
-                    fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch)
             elif update_fp8_weights:
                 # Need to cast weights to FP8
                 fc1_weight_fp8 = Float8Tensor(
@@ -499,6 +495,12 @@ class _LayerNormMLP(torch.autograd.Function):
             fwd_scale_inverses,
         ) = ctx.saved_tensors
+        # Primary weights are in FP8.
+        if ctx.fp8 and fc1_weight_t_fp8 is None:
+            fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=ctx.is_first_microbatch)
+        if ctx.fp8 and fc2_weight_t_fp8 is None:
+            fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=ctx.is_first_microbatch)
         activation_func = _act_func(ctx.activation)[1]
         if ctx.ub_bulk_dgrad:
@@ -151,9 +151,6 @@ class _Linear(torch.autograd.Function):
                 weight.reset_fp8_meta_scale_inv()
                 weight_fp8 = weight
                 weight_t_fp8 = None
-                if is_grad_enabled:
-                    weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)
             elif update_fp8_weights:
                 # Need to cast weights to FP8
                 weight_fp8 = Float8Tensor(
@@ -264,9 +261,6 @@ class _Linear(torch.autograd.Function):
         if is_grad_enabled:
             fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
-            if fp8:
-                assert hasattr(weight_t_fp8, "_data"), \
-                    "_data attr doesn't exist (before save for bwd)"
             ctx.save_for_backward(
                 inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
                 inputmat_t if weight.requires_grad and fp8_wgrad else None,
@@ -317,9 +311,10 @@ class _Linear(torch.autograd.Function):
             weight_t_fp8,
             fwd_scale_inverses,
         ) = ctx.saved_tensors
-        if weight_t_fp8 is not None:
-            assert hasattr(weight_t_fp8, "_data"), \
-                "_data attr doesn't exist (after restore in bwd)"
+        # Primary weights are in FP8.
+        if ctx.fp8 and weight_t_fp8 is None:
+            weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
         if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
             tp_world_size = get_distributed_world_size(ctx.tp_group)