Unverified commit a488b8b1, authored by Xiaowei Ren, committed by GitHub
Browse files

Fix seq_dim in CP implementation (#1264)



fix seq_dim in CP implementation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
parent 12f30ead
@@ -2534,6 +2534,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         causal = "causal" in ctx.attn_mask_type
         padding = "padding" in ctx.attn_mask_type
+        seq_dim = None
         if ctx.qkv_format in ["bshd", "sbhd"]:
             seq_dim = ctx.qkv_format.index("s")
             qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
@@ -2580,7 +2582,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         fused_attn_qkv_dtype = None
         fused_attn_dqkv_dtype = None
         amax_per_step = None
-        seq_dim = None
         dout_fp8_dtype = None
         if ctx.fp8:
             if ctx.use_fused_attention:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment