"src/vscode:/vscode.git/clone" did not exist on "0814b17129bdbc8a56e325ff7ef6cca695140a44"
Commit 0e2eae42 authored by Lei Wang, committed by GitHub
Browse files

[Refactor] Update BitBLAS Benchmark with TileLang Carver Imports and Roller Hints Generation (#148)

- Replace BitBLAS imports with TileLang Carver imports in benchmark_matmul.py
- Modify roller hints generation using new TileLang Carver template and utility functions
- Update get_roller_hints_from_func to handle None cases and improve return logic
- Adjust DefaultPolicy to handle different codegen dictionary formats
parent d0434c3e
......@@ -46,33 +46,29 @@ def get_configs(M, N, K, with_roller=False):
thread numbers, and other parameters to explore during autotuning.
"""
if with_roller:
from bitblas.base.utils import get_roller_hints_from_func
from bitblas.ops.general_matmul.tirscript import matmul_select_implementation
from bitblas.base.arch import CUDA
from bitblas.base.roller.rasterization import NoRasterization
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda")
topk = 20
# Simple TIR Compute Expression
ir_module = matmul_select_implementation(
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
)
).with_arch(arch)
roller_hints = get_roller_hints_from_func(
ir_module,
arch,
topk,
tensorcore_only=True,
allow_gemv=True,
)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
roller_hints = carve_template.recommend_hints(topk=topk)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
......
......@@ -94,7 +94,10 @@ class DefaultPolicy:
self._expand_reduce_axis(td)
for codegen_dicts in self.assign_block_size(td):
results.append(codegen_dicts)
if isinstance(codegen_dicts, dict) and len(codegen_dicts) == 1:
results.append(list(codegen_dicts.values())[0])
else:
results.append(codegen_dicts)
if len(results) >= topk:
break
if len(results) >= topk:
......
......@@ -44,6 +44,7 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
assert func is not None, "The function should not be None"
roller_hints = None
if tensorcore_only:
try:
tensorized_func, tags = get_tensorized_func_and_tags(
......@@ -53,9 +54,9 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
tags = None
if tags and tensorized_func:
policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)
return policy.emit_config(topk)
roller_hints = policy.emit_config(topk)
else:
return None
roller_hints = None
else:
policy = DefaultPolicy.from_prim_func(func=func, arch=arch)
tensorized_func = None
......@@ -67,7 +68,10 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
tags = None
if tags and tensorized_func:
policy = TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags)
return policy.emit_config(topk)
roller_hints = policy.emit_config(topk)
else:
roller_hints = None
return roller_hints
def get_roller_hints_from_output_nodes(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment