Unverified Commit 454e3895 authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files

Added offloading support for FP8 attention (#1131)



* Added offloading support for FP8 attention
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Selvaraj Anandaraj <anandaraj@wisc.edu>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Signed-off-by: Selvaraj Anandaraj <anandaraj@wisc.edu>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5fafeb0e
@@ -5698,16 +5698,23 @@ class FusedAttnFunc(torch.autograd.Function):
             out_save = out_ret
             fp8_tensors = (None, None, None, None, None, None)

+        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
+
         from .cpu_offload import CPUOffloadEnabled
         if CPUOffloadEnabled:
-            tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
+            if ctx.fp8:
+                tensor_list = fp8_tensors
+            else:
+                tensor_list = [q, k, v, out_save]
+            tensor_list.extend(aux_ctx_tensors)
+
             qkv_layout = "sbhd_sbhd_sbhd"
             for tensor in tensor_list:
                 if tensor is not None:
                     tensor.activation_offloading = True

-        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
         qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment