Commit 14b4d6bb authored by Jim's avatar Jim
Browse files

fix fp16 case

parent 9f24a7ed
...@@ -800,11 +800,9 @@ class FmhaBwdApiPool: ...@@ -800,11 +800,9 @@ class FmhaBwdApiPool:
v3_code = str() v3_code = str()
for i, dtype in enumerate(self.dq_dk_dv_v3_pool.keys()): for i, dtype in enumerate(self.dq_dk_dv_v3_pool.keys()):
per_bf16 = str() per_bf16_cvt = str()
for j, bf16_cvt in enumerate([0, 1, 2]): for j, bf16_cvt in enumerate([0, 1, 2]):
per_mask = str() per_mask = str()
if (dtype == "fp16") and (bf16_cvt in [1, 2]):
continue
for k, is_causal in enumerate(["t", "f"]): for k, is_causal in enumerate(["t", "f"]):
per_atomic = str() per_atomic = str()
for l, is_atomic in enumerate(["t", "f"]): for l, is_atomic in enumerate(["t", "f"]):
...@@ -812,10 +810,11 @@ class FmhaBwdApiPool: ...@@ -812,10 +810,11 @@ class FmhaBwdApiPool:
for m, hdim in enumerate(BWD_V3_HDIM_CASE_CHECK_MAP.values()): for m, hdim in enumerate(BWD_V3_HDIM_CASE_CHECK_MAP.values()):
if_m = 'if' if m == 0 else 'else if' if_m = 'if' if m == 0 else 'else if'
inners = str() inners = str()
bf16_cvt_tmp = 0 if dtype == "fp16" else bf16_cvt
if is_atomic == "t": if is_atomic == "t":
inners = FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt, F_padding=BWD_V3_PADDING_CHECK_MAP[m]) inners = FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt_tmp, F_padding=BWD_V3_PADDING_CHECK_MAP[m])
else: else:
inners = FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt, F_padding=BWD_V3_PADDING_CHECK_MAP[m]) inners = FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt_tmp, F_padding=BWD_V3_PADDING_CHECK_MAP[m])
per_hdim = per_hdim + FMHA_BWD_V3_PER_HDIM_CASE.format(F_if=if_m, F_hdim_expression=BWD_V3_HDIM_CASE_MAP[m], inner_dispatch=inners) per_hdim = per_hdim + FMHA_BWD_V3_PER_HDIM_CASE.format(F_if=if_m, F_hdim_expression=BWD_V3_HDIM_CASE_MAP[m], inner_dispatch=inners)
if_l = 'if' if l == 0 else 'else if' if_l = 'if' if l == 0 else 'else if'
...@@ -823,9 +822,9 @@ class FmhaBwdApiPool: ...@@ -823,9 +822,9 @@ class FmhaBwdApiPool:
if_k = 'if' if k == 0 else 'else if' if_k = 'if' if k == 0 else 'else if'
per_mask = per_mask + FMHA_BWD_V3_PER_MASK_CASE.format(F_if=if_k, F_mask_expression=BWD_V3_MASK_MAP[is_causal], per_atomic_dispatch=per_atomic) per_mask = per_mask + FMHA_BWD_V3_PER_MASK_CASE.format(F_if=if_k, F_mask_expression=BWD_V3_MASK_MAP[is_causal], per_atomic_dispatch=per_atomic)
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_bf16 = per_bf16 + FMHA_BWD_V3_PER_BF16_CVT_CASE.format(F_if=if_j, F_bf16_cvt=bf16_cvt, per_mask_dispatch=per_mask) per_bf16_cvt = per_bf16_cvt + FMHA_BWD_V3_PER_BF16_CVT_CASE.format(F_if=if_j, F_bf16_cvt=bf16_cvt, per_mask_dispatch=per_mask)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
v3_code = v3_code + FMHA_BWD_V3_PER_DTYPE_CASE.format(F_if=if_i, F_dtype=dtype, per_bf16_cvt_dispatch=per_bf16) v3_code = v3_code + FMHA_BWD_V3_PER_DTYPE_CASE.format(F_if=if_i, F_dtype=dtype, per_bf16_cvt_dispatch=per_bf16_cvt)
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes, F_template = gen_template, F_v3_dispatch = v3_code) return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes, F_template = gen_template, F_v3_dispatch = v3_code)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment