Commit a04af29a authored by Po Yen, Chen's avatar Po Yen, Chen
Browse files

Use qr_2wave pipeline for fp8

parent 9995c65c
...@@ -467,6 +467,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm ...@@ -467,6 +467,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
elif dtype in ['fp8', 'bf8']: elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels # no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
if receipt == 3:
pipelines.append(FmhaFwdPipeline('qr_2wave', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
else: else:
assert False assert False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment