Commit 098487d7 authored by Andy Lugo's avatar Andy Lugo
Browse files

Review comments

parent 6f5c8e3a
...@@ -487,14 +487,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm ...@@ -487,14 +487,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
if kernel_filter != None: if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter): if not fnmatch.fnmatch(k.name, kernel_filter):
continue continue
if receipt == 2: if receipt in (2, 3):
cond = dtype in ['fp16', 'bf16'] cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f' cond &= pipeline.F_squant == 'f'
if not cond: if not cond:
continue continue
if receipt in (3, 4): if receipt == 4:
cond = dtype in ['fp16', 'bf16'] cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias'] cond &= pipeline.F_bias in ['no', 'bias']
......
...@@ -103,7 +103,8 @@ if __name__ == "__main__": ...@@ -103,7 +103,8 @@ if __name__ == "__main__":
required=False, required=False,
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
" 1: generate more instance to cover all hdim\n" + \ " 1: generate more instance to cover all hdim\n" + \
" 2: Only generate instance for Flash attention integration" " 2: Only generate instance for Flash attention integration\n" + \
" 4: Only generate instance for PyTorch integration"
) )
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment