Commit 9d772b9a authored by Jim's avatar Jim
Browse files

fix padding case

parent 14b4d6bb
...@@ -628,12 +628,12 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) ...@@ -628,12 +628,12 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode})
FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH=""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>; FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH=""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>;
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;""" return r;"""
FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH=""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>; FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH=""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;""" return r;"""
FMHA_BWD_V3_PER_DTYPE_CASE=""" {F_if} (t.data_type.compare(\"{F_dtype}\") == 0) {{ FMHA_BWD_V3_PER_DTYPE_CASE=""" {F_if} (t.data_type.compare(\"{F_dtype}\") == 0) {{
...@@ -811,10 +811,11 @@ class FmhaBwdApiPool: ...@@ -811,10 +811,11 @@ class FmhaBwdApiPool:
if_m = 'if' if m == 0 else 'else if' if_m = 'if' if m == 0 else 'else if'
inners = str() inners = str()
bf16_cvt_tmp = 0 if dtype == "fp16" else bf16_cvt bf16_cvt_tmp = 0 if dtype == "fp16" else bf16_cvt
padding_suffix = "_hdp" if BWD_V3_PADDING_CHECK_MAP[m] == "true" else ""
if is_atomic == "t": if is_atomic == "t":
inners = FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt_tmp, F_padding=BWD_V3_PADDING_CHECK_MAP[m]) inners = FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt_tmp, F_padding=BWD_V3_PADDING_CHECK_MAP[m], F_padding_suffix=padding_suffix)
else: else:
inners = FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt_tmp, F_padding=BWD_V3_PADDING_CHECK_MAP[m]) inners = FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt_tmp, F_padding=BWD_V3_PADDING_CHECK_MAP[m], F_padding_suffix=padding_suffix)
per_hdim = per_hdim + FMHA_BWD_V3_PER_HDIM_CASE.format(F_if=if_m, F_hdim_expression=BWD_V3_HDIM_CASE_MAP[m], inner_dispatch=inners) per_hdim = per_hdim + FMHA_BWD_V3_PER_HDIM_CASE.format(F_if=if_m, F_hdim_expression=BWD_V3_HDIM_CASE_MAP[m], inner_dispatch=inners)
if_l = 'if' if l == 0 else 'else if' if_l = 'if' if l == 0 else 'else if'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment