"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "5315d9bb2672a5db993e769636737f9a4ed233ce"
Unverified Commit a488b8b1 authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

Fix seq_dim in CP implementation (#1264)



fix seq_dim in CP implementation
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
parent 12f30ead
......@@ -2534,6 +2534,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
causal = "causal" in ctx.attn_mask_type
padding = "padding" in ctx.attn_mask_type
seq_dim = None
if ctx.qkv_format in ["bshd", "sbhd"]:
seq_dim = ctx.qkv_format.index("s")
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
......@@ -2580,7 +2582,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fused_attn_qkv_dtype = None
fused_attn_dqkv_dtype = None
amax_per_step = None
seq_dim = None
dout_fp8_dtype = None
if ctx.fp8:
if ctx.use_fused_attention:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment