Unverified Commit d1488e73 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Fix multiple calls to saved_tensors in CP attention (#1334)



* Limit to one call of ctx.saved_tensors per autograd bwd
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 28aa41a3
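
For context, all three hunks below apply one pattern: backward reads ctx.saved_tensors exactly once into a local tuple and then slices that local tuple, instead of going back to ctx.saved_tensors for every group of tensors. A minimal sketch of the pattern (illustrative names, not TransformerEngine code):

import torch

class SliceOnce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, c):
        ctx.save_for_backward(a, b, c)
        return a * b + c

    @staticmethod
    def backward(ctx, grad_out):
        # One ctx.saved_tensors access per backward, as in the diff below;
        # all later reads slice the local tuple rather than ctx.
        (*saved_tensors,) = ctx.saved_tensors
        a, b = saved_tensors[:2]
        return grad_out * b, grad_out * a, grad_out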
@@ -2528,12 +2528,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
         batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

-        (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
-        (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8]
-        cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size]
-        cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
-        rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
-        attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
+        (*saved_tensors,) = ctx.saved_tensors
+        (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6]
+        (fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8]
+        cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size]
+        cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2]
+        rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
+        attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]

         causal = "causal" in ctx.attn_mask_type
         padding = "padding" in ctx.attn_mask_type
@@ -3577,11 +3578,12 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
         cp_size = get_distributed_world_size(ctx.cp_group)
         rank = get_distributed_rank(ctx.cp_group)

-        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5]
-        cu_seqlens_kv_per_step = ctx.saved_tensors[5:7]
-        out_per_step = ctx.saved_tensors[7:9]
-        softmax_lse_per_step = ctx.saved_tensors[9:11]
-        rng_states = ctx.saved_tensors[11:13]
+        (*saved_tensors,) = ctx.saved_tensors
+        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5]
+        cu_seqlens_kv_per_step = saved_tensors[5:7]
+        out_per_step = saved_tensors[7:9]
+        softmax_lse_per_step = saved_tensors[9:11]
+        rng_states = saved_tensors[11:13]

         kv_seq_range_per_step = ctx.kv_seq_range_per_step
         window_size_per_step = ctx.window_size_per_step
@@ -4056,12 +4058,11 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         # pylint: disable=missing-function-docstring
         cp_size = get_distributed_world_size(ctx.cp_group)

-        q, k, v, out = ctx.saved_tensors[:4]
-        cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[
-            4:8
-        ]
-        fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10]
-        aux_ctx_tensors = ctx.saved_tensors[10:]
+        (*saved_tensors,) = ctx.saved_tensors
+        q, k, v, out = saved_tensors[:4]
+        cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8]
+        fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10]
+        aux_ctx_tensors = saved_tensors[10:]

         qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
         causal = "causal" in ctx.attn_mask_type