Commit 8b9edc3e authored by Chaofan Lin, committed by GitHub

[Refactor] Remove BitBLAS Import in Benchmark (#150)

parent 0e2eae42
@@ -149,13 +149,6 @@ def matmul(M, N, K, with_roller):
     # - A reference program for correctness verification
     # - The "tvm" profiler backend
     # - HIP as the compilation target (modify as needed for your hardware)
-    if with_roller:
-        # check out bitblas is installed
-        try:
-            import bitblas  # noqa: F401
-        except ImportError as e:
-            raise ImportError(
-                "BitBlas is not installed. Please install it via 'pip install bitblas'.") from e
     @autotune(
         configs=get_configs(M, N, K, with_roller),
...
@@ -49,33 +49,30 @@ def get_configs(M, N, K, with_roller=False):
     thread numbers, and other parameters to explore during autotuning.
     """
     if with_roller:
-        from bitblas.base.utils import get_roller_hints_from_func
-        from bitblas.ops.general_matmul.tirscript import matmul_select_implementation
-        from bitblas.base.arch import CUDA
-        from bitblas.base.roller.rasterization import NoRasterization
+        from tilelang.carver.template import MatmulTemplate
+        from tilelang.carver.arch import CUDA
+        from tilelang.carver.roller.rasterization import NoRasterization
         arch = CUDA("cuda")
         topk = 20
         # Simple TIR Compute Expression
-        ir_module = matmul_select_implementation(
+        carve_template = MatmulTemplate(
             M=M,
             N=N,
             K=K,
             in_dtype="float16",
             out_dtype="float16",
             accum_dtype="float16",
-        )
-        roller_hints = get_roller_hints_from_func(
-            ir_module,
-            arch,
-            topk,
-            tensorcore_only=True,
-            allow_gemv=True,
-        )
+        ).with_arch(arch)
+
+        func = carve_template.equivalent_function()
+        assert func is not None, "Function is None"
+
+        roller_hints = carve_template.recommend_hints(topk=topk)
+
         if roller_hints is None:
             raise ValueError("No Roller Hints Found for TensorCore Scheduling")
         configs = []
         for hint in roller_hints:
             config = {}
@@ -156,13 +153,6 @@ def matmul(M, N, K, with_roller):
     # - A reference program for correctness verification
     # - The "tvm" profiler backend
     # - HIP as the compilation target (modify as needed for your hardware)
-    if with_roller:
-        # check out bitblas is installed
-        try:
-            import bitblas  # noqa: F401
-        except ImportError as e:
-            raise ImportError(
-                "BitBlas is not installed. Please install it via 'pip install bitblas'.") from e
     @autotune(
         configs=get_configs(M, N, K, with_roller),
...
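
For context, the sketch below is a minimal, self-contained rendering of the new carver-based path introduced by this change. It only reuses the calls visible in the diff above (MatmulTemplate, with_arch, equivalent_function, recommend_hints) and assumes tilelang is installed with a CUDA device available; instead of building the benchmark's full autotune config dicts, it simply prints the recommended hints.

from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA


def get_roller_hints(M: int, N: int, K: int, topk: int = 20):
    # Describe the fp16 GEMM workload with the carver template and bind the target arch.
    arch = CUDA("cuda")
    carve_template = MatmulTemplate(
        M=M,
        N=N,
        K=K,
        in_dtype="float16",
        out_dtype="float16",
        accum_dtype="float16",
    ).with_arch(arch)

    # The template must lower to a valid TIR function before hints can be derived.
    func = carve_template.equivalent_function()
    assert func is not None, "Function is None"

    # Ask the roller for the top-k scheduling hints for this workload.
    hints = carve_template.recommend_hints(topk=topk)
    if hints is None:
        raise ValueError("No Roller Hints Found for TensorCore Scheduling")
    return hints


if __name__ == "__main__":
    # Hypothetical usage: inspect the hints for a 1024x1024x1024 GEMM.
    for hint in get_roller_hints(1024, 1024, 1024):
        print(hint)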