Commit 0e2eae42 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Update BitBLAS Benchmark with TileLang Carver Imports and Roller Hints Generation (#148)

- Replace BitBLAS imports with TileLang Carver imports in benchmark_matmul.py
- Modify roller hints generation using new TileLang Carver template and utility functions
- Update get_roller_hints_from_func to handle None cases and improve return logic
- Adjust DefaultPolicy to handle different codegen dictionary formats
parent d0434c3e
...@@ -46,33 +46,29 @@ def get_configs(M, N, K, with_roller=False): ...@@ -46,33 +46,29 @@ def get_configs(M, N, K, with_roller=False):
thread numbers, and other parameters to explore during autotuning. thread numbers, and other parameters to explore during autotuning.
""" """
if with_roller: if with_roller:
from bitblas.base.utils import get_roller_hints_from_func from tilelang.carver.template import MatmulTemplate
from bitblas.ops.general_matmul.tirscript import matmul_select_implementation from tilelang.carver.arch import CUDA
from bitblas.base.arch import CUDA from tilelang.carver.roller.rasterization import NoRasterization
from bitblas.base.roller.rasterization import NoRasterization
arch = CUDA("cuda") arch = CUDA("cuda")
topk = 20 topk = 20
# Simple TIR Compute Expression carve_template = MatmulTemplate(
ir_module = matmul_select_implementation(
M=M, M=M,
N=N, N=N,
K=K, K=K,
in_dtype="float16", in_dtype="float16",
out_dtype="float16", out_dtype="float16",
accum_dtype="float16", accum_dtype="float16",
) ).with_arch(arch)
roller_hints = get_roller_hints_from_func( func = carve_template.equivalent_function()
ir_module, assert func is not None, "Function is None"
arch,
topk, roller_hints = carve_template.recommend_hints(topk=topk)
tensorcore_only=True,
allow_gemv=True,
)
if roller_hints is None: if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling") raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = [] configs = []
for hint in roller_hints: for hint in roller_hints:
config = {} config = {}
......
...@@ -94,7 +94,10 @@ class DefaultPolicy: ...@@ -94,7 +94,10 @@ class DefaultPolicy:
self._expand_reduce_axis(td) self._expand_reduce_axis(td)
for codegen_dicts in self.assign_block_size(td): for codegen_dicts in self.assign_block_size(td):
results.append(codegen_dicts) if isinstance(codegen_dicts, dict) and len(codegen_dicts) == 1:
results.append(list(codegen_dicts.values())[0])
else:
results.append(codegen_dicts)
if len(results) >= topk: if len(results) >= topk:
break break
if len(results) >= topk: if len(results) >= topk:
......
...@@ -44,6 +44,7 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule], ...@@ -44,6 +44,7 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
assert func is not None, "The function should not be None" assert func is not None, "The function should not be None"
roller_hints = None
if tensorcore_only: if tensorcore_only:
try: try:
tensorized_func, tags = get_tensorized_func_and_tags( tensorized_func, tags = get_tensorized_func_and_tags(
...@@ -53,9 +54,9 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule], ...@@ -53,9 +54,9 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
tags = None tags = None
if tags and tensorized_func: if tags and tensorized_func:
policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)
return policy.emit_config(topk) roller_hints = policy.emit_config(topk)
else: else:
return None roller_hints = None
else: else:
policy = DefaultPolicy.from_prim_func(func=func, arch=arch) policy = DefaultPolicy.from_prim_func(func=func, arch=arch)
tensorized_func = None tensorized_func = None
...@@ -67,7 +68,10 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule], ...@@ -67,7 +68,10 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
tags = None tags = None
if tags and tensorized_func: if tags and tensorized_func:
policy = TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags) policy = TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags)
return policy.emit_config(topk) roller_hints = policy.emit_config(topk)
else:
roller_hints = None
return roller_hints
def get_roller_hints_from_output_nodes( def get_roller_hints_from_output_nodes(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment