Commit 7d45045c authored by Jim
Browse files

update: add restricts

parent 7857f621
......@@ -162,8 +162,8 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_V3_TEMPLATE="""template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr const char * bwd_v3_name = "bwd_v3{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr unsigned char * bwd_v3_buf = bwd{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}; }};
FMHA_BWD_V3_TEMPLATE="""template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr const char * bwd_v3_name = "bwd_v3{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}{F_hdpad_name}"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr unsigned char * bwd_v3_buf = bwd{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}{F_hdpad_name}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr int ts_qo = {F_Ts_qo}; static constexpr int ts_kv = 192; }};
"""
......@@ -782,14 +782,18 @@ class FmhaBwdApiPool:
hdim = int(hdim)
Ts_qo = 32 if hdim == 64 else 16
for k, trait in enumerate(traits):
if hdim == 64 and trait.is_hdpad == "t":
continue
hdim_name = "_hd64" if hdim == 64 else ""
dtype_name = "_{}".format(dtype)
causal_name = "_causal" if trait.is_causal == "t" else ""
atomic_name = "_a32" if trait.is_atomic == "t" else "_a16"
bf16_cvt_name = "_{}".format(BF16_CVT_MAP[trait.bf16_cvt])
bf16_cvt_name = bf16_cvt_name if dtype == "bf16" else ""
hdpad_name = "_pddv" if trait.is_hdpad == "t" else ""
gen_template = gen_template + FMHA_BWD_V3_TEMPLATE.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_atomic=BOOL_MAP[trait.is_atomic],
F_is_causal=BOOL_MAP[trait.is_causal], F_bf16_cvt=trait.bf16_cvt, F_hdpad=BOOL_MAP[trait.is_hdpad], F_Ts_qo = Ts_qo, F_hdim_name=hdim_name,
F_dtype_name=dtype_name, F_causal_name=causal_name, F_atomic_name=atomic_name, F_bf16_cvt_name=bf16_cvt_name)
F_dtype_name=dtype_name, F_causal_name=causal_name, F_atomic_name=atomic_name, F_bf16_cvt_name=bf16_cvt_name, F_hdpad_name=hdpad_name)
v3_code = str()
for i, dtype in enumerate(self.dq_dk_dv_v3_pool.keys()):
......@@ -1011,14 +1015,14 @@ class FmhaBwdV3DQDKDVKernel:
# Look up the tile-size / pipeline table for a dtype string.
# For 'fp16' or 'bf16' this returns a dict keyed by the strings listed below
# (presumably head-dimension values -- TODO confirm). Each value is a list
# holding one FmhaBwdDQDKDVTileSize followed by two pipeline-name strings.
# Any other dtype returns None.
# NOTE(review): the '32' and '256' entries appear here both live and commented
# out; this looks like the before/after lines of a diff shown together --
# confirm which form is intended.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
# '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
# "kr_ktr_vr_iglp", "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"]
# '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# "kr_ktr_vr_iglp", "kr_ktr_vr"]
}
else:
# Dtypes other than fp16/bf16 have no table in this function.
return None
......@@ -1061,8 +1065,11 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
if receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= bias in ['no']
cond &= dropout in ['no']
cond &= dpad == dvpad
cond &= spad == skpad
cond &= spad == "f"
cond &= deterministic == "f"
if not cond:
continue
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment