Commit ca2a0ebd authored by danyao12's avatar danyao12
Browse files

Merge branch 'ck_tile/fa_bwd_opt' of https://github.com/ROCm/composable_kernel...

Merge branch 'ck_tile/fa_bwd_opt' of https://github.com/ROCm/composable_kernel into ck_tile/fa_bwd_opt
parents dcc3593f 85ec3c75
...@@ -279,8 +279,8 @@ class FmhaBwdApiPool: ...@@ -279,8 +279,8 @@ class FmhaBwdApiPool:
continue continue
if (spad1 == "t" and trait.spad == "f" and hdim_int <= 64): if (spad1 == "t" and trait.spad == "f" and hdim_int <= 64):
continue continue
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
...@@ -490,6 +490,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -490,6 +490,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if receipt == 2: if receipt == 2:
cond = dtype in ['fp16', 'bf16'] cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi'] cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
if not cond: if not cond:
continue continue
api_pool.register_dq_dk_dv_traits(k.api_trait()) api_pool.register_dq_dk_dv_traits(k.api_trait())
...@@ -693,7 +694,7 @@ class FmhaBwdConvertQGradKernel: ...@@ -693,7 +694,7 @@ class FmhaBwdConvertQGradKernel:
F_spad : str # true/false F_spad : str # true/false
F_dpad : str # F_dpad : str #
F_mode : str # value from MODE_MAP F_mode : str # value from MODE_MAP
F_occupancy : int # F_occupancy : int #
F_deterministic : str # F_deterministic : str #
@property @property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment