Unverified Commit d9eb0582 authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files

[PyTorch] Added attention activation offloading support for TE v2.0 (#1671)



* Added attention activation offloading support for TE v2.0
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>
parent c638c436
......@@ -80,6 +80,7 @@ import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
from .cpu_offload import set_offloading_param
# Setup Attention Logging
......@@ -4324,7 +4325,7 @@ class FlashAttention(torch.nn.Module):
tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
for tensor in tensor_list:
if tensor is not None:
tensor.activation_offloading = True
set_offloading_param(tensor, "activation_offloading", True)
with self.attention_dropout_ctx():
# | API | use cases
......@@ -4726,12 +4727,14 @@ class FusedAttnFunc(torch.autograd.Function):
else:
tensor_list = [q, k, v, out_save]
tensor_list.extend(aux_ctx_tensors)
qkv_layout = "sbhd_sbhd_sbhd"
for tensor in tensor_list:
if tensor is not None:
tensor.activation_offloading = True
set_offloading_param(tensor, "activation_offloading", True)
for tensor in aux_ctx_tensors:
if tensor is not None:
set_offloading_param(tensor, "activation_offloading", True)
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment