Commit a0c92495 authored by danyao12's avatar danyao12
Browse files

codegen update

parent 7e9d2390
......@@ -277,7 +277,7 @@ class FmhaBwdApiPool:
for spad1 in ["t", "f"]:
if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
continue
if (spad1 == "t" and trait.spad == "f" and hdim_int <= 64):
if (spad1 == "t" and trait.spad == "f" and hdim_int == 64):
continue
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment