Commit c1e2fef7 authored by rocking's avatar rocking
Browse files

Add receipt for aiter integration

parent 052a7265
...@@ -499,13 +499,30 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -499,13 +499,30 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= dpad == dvpad cond &= dpad == dvpad
if not cond: if not cond:
continue continue
if receipt == 3: elif receipt == 3:
cond = dtype in ['fp16', 'bf16'] cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi'] cond &= bias in ['no', 'alibi']
cond &= dpad == dvpad cond &= dpad == dvpad
cond &= deterministic == "f" cond &= deterministic == "f"
if not cond: if not cond:
continue continue
elif receipt == 10:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
continue
elif receipt == 11:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
continue
api_pool.register_dq_dk_dv_traits(k.api_trait()) api_pool.register_dq_dk_dv_traits(k.api_trait())
gen.append(k) gen.append(k)
......
...@@ -494,6 +494,22 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm ...@@ -494,6 +494,22 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
cond &= pipeline.F_squant == 'f' cond &= pipeline.F_squant == 'f'
if not cond: if not cond:
continue continue
elif receipt == 10:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
elif receipt == 11:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait()) api_pool.register_traits(k.api_trait())
gen.append(k) gen.append(k)
......
...@@ -712,6 +712,20 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -712,6 +712,20 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= pipeline.F_squant == 'f' cond &= pipeline.F_squant == 'f'
if not cond: if not cond:
continue continue
elif receipt == 11:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
elif receipt == 14:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait()) api_pool.register_traits(k.api_trait())
gen.append(k) gen.append(k)
......
...@@ -103,7 +103,10 @@ if __name__ == "__main__": ...@@ -103,7 +103,10 @@ if __name__ == "__main__":
required=False, required=False,
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
" 1: generate more instance to cover all hdim\n" + \ " 1: generate more instance to cover all hdim\n" + \
" 2: Only generate instance for Flash attention integration" " 2: Only generate instance for Flash attention integration" + \
" 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration"
" 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration"
" 12: Only generate instance for Aiter(mha_fwd_kvcache) integration"
) )
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment