"ts/webui/src/static/style/index.css" did not exist on "e2d8cc1bcc83bb4485467d738beb19984e6b9cea"
Unverified commit 29051439, authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
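The hunks below are a pure reformatting pass: every change swaps yapf's packed line wrapping for ruff format's black-style output (double quotes, one argument per line once a call wraps, a trailing comma, and the closing bracket on its own line). A minimal sketch of that style shift, using a hypothetical `configure` helper rather than any function from this repository:

def configure(*args, **kwargs):
    """Stub standing in for any long call site touched by this commit."""
    return args, kwargs

# yapf-era layout: arguments packed onto shared lines, closing paren hugging the last one.
configure('auto', target_host=None, verbose=False,
          pass_configs=None)

# ruff format layout: double quotes, one argument per line when the call wraps,
# a trailing comma, and the closing paren on its own line.
configure(
    "auto",
    target_host=None,
    verbose=False,
    pass_configs=None,
)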
@@ -32,7 +32,6 @@ block_K = 32
def test_warp_specialized():
    @T.prim_func
    def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
        bx = T.launch_thread("blockIdx.x", 8)
@@ -47,25 +46,27 @@ def test_warp_specialized():
        for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
            if v == 0:
                T.tma_load(
                    T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
                    0,
                    T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
                    k * 32,
                    by * 64,
                )
            if v == 0:
                T.tma_load(
                    T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
                    0,
                    T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
                    bx * 64,
                    k * 32,
                )
            T.call_extern(
                "handle",
                "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
                T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
                T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
                T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
            )

    @T.prim_func
    def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
@@ -85,34 +86,35 @@ def test_warp_specialized():
                    T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
                if v - 128 == 0:
                    T.tma_load(
                        T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
                        T.get_mbarrier(k % 3),
                        T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
                        k * 32,
                        by * 64,
                    )
                if v - 128 == 0:
                    T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
                if v - 128 == 0:
                    T.tma_load(
                        T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
                        T.get_mbarrier(k % 3),
                        T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
                        bx * 64,
                        k * 32,
                    )
                T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
        else:
            T.set_max_nreg(240, 1)
            for k in range(16):
                T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
                T.call_extern(
                    "handle",
                    "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
                    T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
                    T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
                    T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
                )
                T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))

    _check(before, after)
...
@@ -6,7 +6,7 @@ from tilelang.utils.sparse import compress_sm90, randn_semi_sparse
def _test_compress_sm90(M, K, block_k, dtype):
    A = randn_semi_sparse(M, K, dtype=dtype, device="cuda")
    A_sparse, E = compress_sm90(A, block_k, False)
...
@@ -5,12 +5,11 @@ import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...
@@ -23,6 +23,7 @@ def _compute_version() -> str:
    if version_file.is_file():
        try:
            from version_provider import dynamic_metadata  # type: ignore

            return dynamic_metadata("version")
        except Exception:
            # Fall back to the raw VERSION file if provider isn't available.
@@ -33,6 +34,7 @@ def _compute_version() -> str:
    try:
        from importlib.metadata import version as _dist_version  # py3.8+

        return _dist_version("tilelang")
    except Exception as exc:
        warnings.warn(
...
from __future__ import annotations
from tvm import tir
from tvm.tir import PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm
from tvm.tir.transform import prim_func_pass
from tvm.tir.stmt_functor import post_order_visit
@@ -22,14 +22,14 @@ class _LoopVarUseAnalyzer(PyStmtExprVisitor):
def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]:
    """
    Collect local buffer accesses in the loop body.

    Args:
        statement: The TIR statement to analyze

    Returns:
        Tuple of buffer accesses in the loop body.
    """
    buffer_accesses = []
@@ -44,7 +44,6 @@ def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]:
@tir.functor.visitor
class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
    def __init__(self) -> None:
        super().__init__()
@@ -75,7 +74,8 @@ class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
                raise ValueError(
                    "[Tilelang Semantic Check] "
                    f"Loop variable {loop.loop_var} in a T.Parallel loop with symbolic range (min={loop.min}, extent={loop.extent}) is used to index "
                    "a local/fragment buffer, which is not allowed in Tilelang."
                )
        return
...
@@ -23,10 +23,7 @@ def print_fragment_format(layout: T.Fragment) -> str:
    if isinstance(layout, T.Fragment):
        input_shape = layout.get_input_shape()
        output_shape = layout.get_output_shape()
        lines = [f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", f" Index: {layout.forward_index}"]
        print("\n".join(lines))
    else:
        raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}")
@@ -82,7 +79,6 @@ class _LayoutVisualVisitor(PyStmtExprVisitor):
def LayoutVisual(formats: str = ""):
    def pass_fn(func: tir.PrimFunc, mod, ctx):
        _LayoutVisualVisitor(formats=formats).visit_stmt(func.body)
        return func
...
@@ -11,10 +11,7 @@ from tvm.tir.transform import prim_func_pass
def is_pipelined_for(op: For) -> bool:
    """Check if a for loop is pipelined."""
    anno_keys = ["num_stages", "tl_pipeline_order", "tl_pipeline_stage", "tl_pipeline_sync", "tl_pipeline_group"]
    return any(key in op.annotations for key in anno_keys)
@@ -26,7 +23,6 @@ def is_tile_op(op: Call) -> bool:
@tir.functor.visitor
class _NestedLoopCheckVisitor(PyStmtExprVisitor):
    def __init__(self) -> None:
        super().__init__()
        self.in_parallel_context = False
@@ -42,27 +38,24 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
            # Otherwise
            if self.in_parallel_context:
                raise ValueError("[Tilelang Semantic Check] Nested parallel loops are not allowed. Please check your loop structure.")
            self.in_parallel_context = True
            super().visit_for_(op)
            self.in_parallel_context = False
            return
        elif is_pipelined_for(op):
            if self.in_parallel_context:
                raise ValueError(
                    "[Tilelang Semantic Check] Pipelined loop cannot be nested inside a parallel loop. Please check your loop structure."
                )
        super().visit_for_(op)

    def visit_call_(self, op: Call) -> None:
        if self.in_parallel_context and is_tile_op(op):
            raise ValueError(
                f'[Tilelang Semantic Check] Only elementwise operations are allowed inside a parallel loop. Got a tile-op "{op.op}".'
            )


def NestedLoopChecker():
...
@@ -85,8 +85,7 @@ def _get_current_stack() -> CaptureStack:
class AutotuneInputsCapture:
    __slots__ = "tensors"

    def __init__(self, tensors: list[Any]):
        self.tensors = tensors
...
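A side note on the `__slots__` hunk above (an observation, not part of the commit): `("tensors")` is a parenthesized string rather than a one-element tuple, so the reformatted spelling is behaviorally identical; an actual single-element tuple would need a trailing comma. A minimal illustration:

class WithParens:
    __slots__ = ("tensors")  # parentheses around a string: still the string "tensors"

class WithoutParens:
    __slots__ = "tensors"  # ruff-formatted equivalent

class RealTuple:
    __slots__ = ("tensors",)  # the trailing comma makes an actual 1-tuple

assert WithParens.__slots__ == WithoutParens.__slots__ == "tensors"
assert RealTuple.__slots__ == ("tensors",)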
"""The auto-tune parameters. """The auto-tune parameters."""
"""
from __future__ import annotations from __future__ import annotations
import tilelang import tilelang
...@@ -50,7 +50,7 @@ class CompileArgs: ...@@ -50,7 +50,7 @@ class CompileArgs:
out_idx: list[int] | int | None = None out_idx: list[int] | int | None = None
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto" execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto"
target: Literal['auto', 'cuda', 'hip'] = 'auto' target: Literal["auto", "cuda", "hip"] = "auto"
target_host: str | Target = None target_host: str | Target = None
verbose: bool = False verbose: bool = False
pass_configs: dict[str, Any] | None = None pass_configs: dict[str, Any] | None = None
...@@ -62,24 +62,20 @@ class CompileArgs: ...@@ -62,24 +62,20 @@ class CompileArgs:
target=self.target, target=self.target,
target_host=self.target_host, target_host=self.target_host,
verbose=self.verbose, verbose=self.verbose,
pass_configs=self.pass_configs) pass_configs=self.pass_configs,
)
def __hash__(self): def __hash__(self):
data = { data = {
"execution_backend": "execution_backend": self.execution_backend,
self.execution_backend, "target": str(self.target),
"target": "target_host": str(self.target_host) if self.target_host else None,
str(self.target), "verbose": self.verbose,
"target_host": "pass_configs": json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None,
str(self.target_host) if self.target_host else None,
"verbose":
self.verbose,
"pass_configs":
json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None,
} }
hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8')) hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8"))
return int.from_bytes(hash_obj.digest(), byteorder='big') return int.from_bytes(hash_obj.digest(), byteorder="big")
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -104,6 +100,7 @@ class ProfileArgs: ...@@ -104,6 +100,7 @@ class ProfileArgs:
manual_check_prog: Callable = None manual_check_prog: Callable = None
cache_input_tensors: bool = True cache_input_tensors: bool = True
""" """
warmup: int = 25 warmup: int = 25
rep: int = 100 rep: int = 100
timeout: int = 30 timeout: int = 30
...@@ -127,8 +124,8 @@ class ProfileArgs: ...@@ -127,8 +124,8 @@ class ProfileArgs:
"atol": self.atol, "atol": self.atol,
"max_mismatched_ratio": self.max_mismatched_ratio, "max_mismatched_ratio": self.max_mismatched_ratio,
} }
hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8')) hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8"))
return int.from_bytes(hash_obj.digest(), byteorder='big') return int.from_bytes(hash_obj.digest(), byteorder="big")
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -143,6 +140,7 @@ class AutotuneResult: ...@@ -143,6 +140,7 @@ class AutotuneResult:
func: Optimized function. func: Optimized function.
kernel: Compiled kernel function. kernel: Compiled kernel function.
""" """
latency: float | None = None latency: float | None = None
config: dict | None = None config: dict | None = None
ref_latency: float | None = None ref_latency: float | None = None
@@ -199,8 +197,7 @@ class AutotuneResult:
            if verbose:
                logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
            if kernel.kernel_source is not None:
                self._safe_write_file(device_kernel_path, "w", lambda f: f.write(kernel.kernel_source))
        except Exception as e:
            logger.error(f"Error saving kernel source code to disk: {e}")
@@ -211,11 +208,9 @@ class AutotuneResult:
                logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
            # Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel
            if kernel.execution_backend == "tvm_ffi":
                self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_host_source()))
            else:
                self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_kernel_source()))
        except Exception as e:
            logger.error(f"Error saving wrapped kernel source code to disk: {e}")
@@ -237,12 +232,10 @@ class AutotuneResult:
                py_src_path = src_lib_path.replace(".cubin", ".py")
                if verbose:
                    logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
                self._safe_write_file(kernel_py_path, "wb", lambda f: f.write(self._load_binary(py_src_path)))
                if verbose:
                    logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
                self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path)))
            elif kernel.execution_backend == "tvm_ffi":
                executable = kernel.adapter.executable
                if verbose:
@@ -252,8 +245,7 @@ class AutotuneResult:
                src_lib_path = kernel.adapter.libpath
                if verbose:
                    logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
                self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path)))
        except Exception as e:
            logger.error(f"Error saving kernel library to disk: {e}")
@@ -370,14 +362,12 @@ class AutotuneResult:
        # save best config (atomic)
        if verbose:
            logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}")
        self._safe_write_file(str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f))

        # save function (atomic)
        if verbose:
            logger.debug(f"Saving function to file: {path / FUNCTION_PATH}")
        self._safe_write_file(str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f))

        # save ref latency (atomic)
        if verbose:
@@ -385,10 +375,13 @@ class AutotuneResult:
        self._safe_write_file(
            str(path / LATENCY_PATH),
            "w",
            lambda f: json.dump(
                {
                    "latency": self.latency,
                    "ref_latency": self.ref_latency,
                },
                f,
            ),
        )

        # save kernel
@@ -403,8 +396,8 @@ class AutotuneResult:
        # Normalize target and resolve execution backend for loading
        from tilelang.utils.target import determine_target as _determine_target
        from tilelang.jit.execution_backend import resolve_execution_backend

        norm_target = Target(_determine_target(compile_args.target)) if isinstance(compile_args.target, str) else compile_args.target
        requested_backend = compile_args.execution_backend
        resolved_backend = resolve_execution_backend(requested_backend, norm_target)
        # load best config
...
@@ -3,6 +3,7 @@
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
"""

from __future__ import annotations
from dataclasses import dataclass
@@ -14,7 +15,8 @@ from tvm.tir import PrimFunc, Var
from tvm.target import Target
import inspect
from functools import partial
from typing import Callable, Generic, Literal, Any, TypeVar

# Python 3.9 compatibility for ParamSpec
try:
    from typing import ParamSpec
@@ -74,8 +76,8 @@ def _init_logger_handlers():
    global _logger_handlers_initialized
    if _logger_handlers_initialized:
        return
    formatter = logging.Formatter("%(asctime)s %(levelname)s:%(message)s")
    file_handler = logging.FileHandler("autotuner.log", mode="w")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(formatter)
    console_handler = logging.StreamHandler(sys.stdout)
@@ -87,8 +89,7 @@ def _init_logger_handlers():
def get_available_cpu_count() -> int:
    """Gets the number of CPU cores available to the current process."""
    try:
        cpu_count = len(os.sched_getaffinity(0))
    except AttributeError:
@@ -107,6 +108,7 @@ class AutoTuner:
        fn: The function to be auto-tuned.
        configs: List of configurations to try during auto-tuning.
    """

    compile_args = CompileArgs()
    profile_args = ProfileArgs()
@@ -137,14 +139,15 @@ class AutoTuner:
        """
        return cls(kernel, configs)

    def set_compile_args(
        self,
        out_idx: list[int] | int | None = None,
        target: Literal["auto", "cuda", "hip", "metal"] = "auto",
        execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
        target_host: str | Target = None,
        verbose: bool = False,
        pass_configs: dict[str, Any] | None = None,
    ):
        """Set compilation arguments for the auto-tuner.

        Args:
@@ -161,6 +164,7 @@ class AutoTuner:
        # Normalize target to a concrete TVM Target and resolve execution backend
        t = Target(determine_target(target))
        from tilelang.jit.execution_backend import resolve_execution_backend

        resolved_backend = resolve_execution_backend(execution_backend, t)
        self.compile_args = CompileArgs(
@@ -169,23 +173,26 @@ class AutoTuner:
            execution_backend=resolved_backend,
            target_host=target_host,
            verbose=verbose,
            pass_configs=pass_configs,
        )
        return self

    def set_profile_args(
        self,
        warmup: int = 25,
        rep: int = 100,
        timeout: int = 30,
        supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
        ref_prog: Callable = None,
        supply_prog: Callable = None,
        rtol: float = 1e-2,
        atol: float = 1e-2,
        max_mismatched_ratio: float = 0.01,
        skip_check: bool = False,
        manual_check_prog: Callable = None,
        cache_input_tensors: bool = False,
    ):
        """Set profiling arguments for the auto-tuner.

        Args:
@@ -209,9 +216,7 @@ class AutoTuner:
        # the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead.
        if get_autotune_inputs() is not None:
            if supply_prog is not None:
                logger.warning("`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context.")
            supply_prog = lambda _: get_autotune_inputs()  # noqa: E731

        self.profile_args = ProfileArgs(
@@ -226,13 +231,13 @@ class AutoTuner:
            cache_input_tensors=cache_input_tensors,
            warmup=warmup,
            rep=rep,
            timeout=timeout,
        )

        # If a custom `supply_prog` is provided, the profiler's `supply_type` setting
        # becomes ineffective. The custom supply program will be used instead.
        if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
            logger.warning("Ignoring `supply_type` passed to `set_profile_args` because `supply_prog` is not None.")
        return self
@@ -241,10 +246,8 @@ class AutoTuner:
        self._kernel_parameters = k_parameters
        self._function_parameters = f_parameters

    def generate_cache_key(self, parameters: dict[str, Any], extra_parameters: dict[str, Any]) -> AutotuneResult | None:
        """Generate a cache key for the auto-tuning process."""

        def _normalize_param(value):
            if isinstance(value, Var):
@@ -315,8 +318,9 @@ class AutoTuner:
            if var_name in parameters:
                continue
            # Cell content must be serializable
            assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), (
                f"Cell contents {cell.cell_contents} is not serializable: {type(cell.cell_contents)}"
            )
            extra_parameters[var_name] = cell.cell_contents

        if isinstance(self.configs, Callable):
@@ -328,8 +332,10 @@ class AutoTuner:
        if env.is_cache_enabled() and not env.is_autotune_cache_disabled():
            # First check in-memory cache
            if key in self._memory_cache:
                logger.warning(
                    "Found kernel in memory cache. For better performance,"
                    " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel."
                )
                return self._memory_cache[key]

            # Then check disk cache
@@ -369,7 +375,6 @@ class AutoTuner:
        # This encapsulates the logic of using either a custom supply program (`supply_prog`)
        # or the default profiler input generation (`profiler._get_inputs`).
        def get_input_tensors_supply(with_output: bool):
            def func():
                if supply_prog is not None:
                    return supply_prog(profiler._get_params(with_output=with_output))
@@ -387,8 +392,7 @@ class AutoTuner:
                self.jit_input_tensors = jit_input_tensors_supply()
            else:
                # check if the cached tensors are compatible with the current configuration
                assert len(params) == len(self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)"
                for p, c in zip(params, self.jit_input_tensors):
                    if not isinstance(c, torch.Tensor):
                        # skip non-tensor inputs checking
@@ -397,8 +401,8 @@ class AutoTuner:
                    # Check tensor compatibility using generator expression
                    def shape_equal(a, b):
                        return all(
                            a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) for a_dim, b_dim in zip(a.shape, b.shape)
                        )

                    if p.dtype != c.dtype or not shape_equal(p, c):
                        logger.warning(
@@ -409,7 +413,8 @@ class AutoTuner:
                            "To ensure fresh, compatible inputs are generated for every trial "
                            "you can disable caching by setting:\n"
                            "    `cache_input_tensors=False`\n"
                            "within your `.set_compile_args(...)` call.\n"
                        )
                        # otherwise, regenerate the input tensors for safety
                        self.jit_input_tensors = jit_input_tensors_supply()
                        break
@@ -418,24 +423,16 @@ class AutoTuner:
            if (not skip_check) and (ref_prog is not None):
                if manual_check_prog is not None:
                    profiler.manual_assert_close(ref_prog, input_tensors=self.jit_input_tensors, manual_check_prog=manual_check_prog)
                else:
                    profiler.assert_allclose(
                        ref_prog, input_tensors=self.jit_input_tensors, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio
                    )
            latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)

            if self.ref_latency_cache is None and ref_prog is not None:
                self.ref_input_tensors = ref_input_tensors_supply()
                self.ref_latency_cache = profiler.do_bench(ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
            return latency, self.ref_latency_cache
@@ -469,17 +466,14 @@ class AutoTuner:
        # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple
        if any(key in top_config for key, _ in key_kwargs_tuple) or any(
            check_tunable_argument_value(key, self._function_parameters, key_args_tuple) for key in tunable_arguments
        ):
            logger.warning(
                f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT"
            )
            # compile the kernel with the provided parameters
            jit_kernel = self.jit_compile()
            autotuner_result = AutotuneResult(libcode=jit_kernel.get_kernel_source(), func=jit_kernel.prim_func, kernel=jit_kernel)
            self._memory_cache[key] = autotuner_result
            return autotuner_result

        # get the cpu count
@@ -489,9 +483,7 @@ class AutoTuner:
        max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
        if cpu_counts > 0:
            num_workers = min(cpu_counts, available_cpu_count)
            logger.info(f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used")
        else:
            num_workers = max(1, int(available_cpu_count * cpu_utilizations))
            logger.info(
@@ -509,7 +501,6 @@ class AutoTuner:
        future_to_index = {}

        def cuda_device_wrapper(func, device):
            def inner(**config_arg):
                torch.cuda.set_device(device)
                return func(**config_arg)
@@ -532,18 +523,14 @@ class AutoTuner:
            future_to_index[future] = i

        results_with_configs = []
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Compiling configurations"):
            idx = future_to_index[future]
            config = config_args[idx]
            try:
                result = future.result()
                results_with_configs.append((result, config))
            except Exception as e:
                logger.debug(f"Compilation failed for config {config} at index {idx} with error: {e}")
                continue

        ref_latency = None
@@ -556,14 +543,10 @@ class AutoTuner:
                # latency, ref_latency = target_fn(jit_kernel)
                latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
            except TimeoutException:
                logger.warning(f"A timeout occurred while testing config {config}, checkout autotuner.log for more details")
                continue
            except Exception:
                logger.warning(f"An error occurred while testing config {config}, checkout autotuner.log for more details")
                logger.debug(f"Error: {traceback.format_exc()}")
                continue
@@ -578,8 +561,7 @@ class AutoTuner:
        pool.shutdown()

        if best_kernel is None:
            error_msg = "Auto-tuning failed: No configuration successfully compiled and passed benchmarking/validation."
            logger.error(error_msg)
            raise RuntimeError(error_msg)
@@ -595,7 +577,8 @@ class AutoTuner:
            ref_latency=ref_latency,
            libcode=best_kernel.get_kernel_source(),
            func=best_kernel.prim_func,
            kernel=best_kernel,
        )

        if self.compile_args.execution_backend in ("torch"):
            logger.warning("DLPack backend does not support cache saving to disk.")
@@ -617,8 +600,8 @@ class AutoTuner:
        return self.run()


_P = ParamSpec("_P")
_T = TypeVar("_T")


@dataclass
@@ -643,8 +626,9 @@ class AutoTuneImpl(Generic[_P, _T]):
        self._tuner_cache = {}

    def get_tunner(self):
        autotuner = (
            AutoTuner(self.jit_impl.func, configs=self.configs)
            .set_profile_args(
                supply_type=self.supply_type,
                ref_prog=self.ref_prog,
                supply_prog=self.supply_prog,
@@ -654,7 +638,8 @@ class AutoTuneImpl(Generic[_P, _T]):
                skip_check=self.skip_check,
                manual_check_prog=self.manual_check_prog,
                cache_input_tensors=self.cache_input_tensors,
            )
            .set_compile_args(
                out_idx=self.jit_impl.out_idx,
                execution_backend=self.jit_impl.execution_backend,
                target=self.jit_impl.target,
@@ -662,6 +647,7 @@ class AutoTuneImpl(Generic[_P, _T]):
                verbose=self.jit_impl.verbose,
                pass_configs=self.jit_impl.pass_configs,
            )
        )
        autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout)
        return autotuner
@@ -753,16 +739,13 @@ def autotune(  # This is the new public interface
    if callable(func):
        # Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults)
        # This is a placeholder for a real auto tuner implementation
        raise ValueError("Use tilelang.autotune to decorate func without arguments is not supported yet.")
    elif isinstance(func, PrimFunc):
        raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
    else:
        def decorator(impl):
            assert isinstance(impl, JITImpl), "The @autotune decorator can only be applied to @tilelang.jit decorated instances."
            return AutoTuneImpl(
                jit_impl=impl,
                configs=configs,
...
"""The cache utils with class and database persistence - Init file""" """The cache utils with class and database persistence - Init file"""
from __future__ import annotations from __future__ import annotations
from typing import Literal from typing import Literal
...@@ -18,8 +19,7 @@ def cached( ...@@ -18,8 +19,7 @@ def cached(
*args, *args,
target: str | Target = "auto", target: str | Target = "auto",
target_host: str | Target = None, target_host: str | Target = None,
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] | None = "auto",
| None = "auto",
verbose: bool | None = False, verbose: bool | None = False,
pass_configs: dict | None = None, pass_configs: dict | None = None,
compile_flags: list[str] | str | None = None, compile_flags: list[str] | str | None = None,
...@@ -36,7 +36,8 @@ def cached( ...@@ -36,7 +36,8 @@ def cached(
execution_backend=execution_backend, execution_backend=execution_backend,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags) compile_flags=compile_flags,
)
def clear_cache(): def clear_cache():
...@@ -47,9 +48,11 @@ def clear_cache(): ...@@ -47,9 +48,11 @@ def clear_cache():
RuntimeError: Always raised to warn users to clear the cache manually. RuntimeError: Always raised to warn users to clear the cache manually.
""" """
cache_dir = env.TILELANG_CACHE_DIR cache_dir = env.TILELANG_CACHE_DIR
raise RuntimeError("tilelang.clear_cache() is disabled because deleting the cache directory " raise RuntimeError(
"is dangerous. If you accept the risk, remove it manually with " "tilelang.clear_cache() is disabled because deleting the cache directory "
f"`rm -rf '{cache_dir}'`.") "is dangerous. If you accept the risk, remove it manually with "
f"`rm -rf '{cache_dir}'`."
)
if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"): if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
......
"""The cache utils with class and database persistence - KernelCache Class""" """The cache utils with class and database persistence - KernelCache Class"""
from __future__ import annotations from __future__ import annotations
import json import json
...@@ -97,9 +98,7 @@ class KernelCache: ...@@ -97,9 +98,7 @@ class KernelCache:
"version": __version__, "version": __version__,
"func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key "func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key
"out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]), "out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]),
"args_repr": tuple( "args_repr": tuple(repr(arg) for arg in args), # Use repr to serialize arguments, may need more robust serialization
repr(arg) for arg in args
), # Use repr to serialize arguments, may need more robust serialization
"target": str(target), "target": str(target),
"target_host": str(target_host) if target_host else None, "target_host": str(target_host) if target_host else None,
"execution_backend": execution_backend, "execution_backend": execution_backend,
...@@ -118,8 +117,7 @@ class KernelCache: ...@@ -118,8 +117,7 @@ class KernelCache:
*args, *args,
target: str | Target = "auto", target: str | Target = "auto",
target_host: str | Target = None, target_host: str | Target = None,
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
"torch"] = "auto",
verbose: bool = False, verbose: bool = False,
pass_configs: dict = None, pass_configs: dict = None,
compile_flags: list[str] | str | None = None, compile_flags: list[str] | str | None = None,
...@@ -140,6 +138,7 @@ class KernelCache: ...@@ -140,6 +138,7 @@ class KernelCache:
# Normalize target and resolve execution backend before proceeding # Normalize target and resolve execution backend before proceeding
from tilelang.utils.target import determine_target as _determine_target from tilelang.utils.target import determine_target as _determine_target
from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target
norm_target = Target(_determine_target(target)) if isinstance(target, str) else target norm_target = Target(_determine_target(target)) if isinstance(target, str) else target
requested_backend = execution_backend requested_backend = execution_backend
execution_backend = resolve_execution_backend(requested_backend, norm_target) execution_backend = resolve_execution_backend(requested_backend, norm_target)
...@@ -180,21 +179,21 @@ class KernelCache: ...@@ -180,21 +179,21 @@ class KernelCache:
with self._lock: with self._lock:
# First check in-memory cache # First check in-memory cache
if key in self._memory_cache: if key in self._memory_cache:
self.logger.warning("Found kernel in memory cache. For better performance," \ self.logger.warning(
" consider using `@tilelang.jit` instead of direct kernel caching.") "Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching."
)
return self._memory_cache[key] return self._memory_cache[key]
if verbose: if verbose:
self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}") self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}")
# Then check disk cache # Then check disk cache
kernel = self._load_kernel_from_disk(key, norm_target, target_host, out_idx, kernel = self._load_kernel_from_disk(
execution_backend, pass_configs, compile_flags, key, norm_target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose
func, verbose) )
if kernel is not None: if kernel is not None:
if verbose: if verbose:
self.logger.debug( self.logger.debug(f"Found kernel in disk cache for {func.attrs['global_symbol']}")
f"Found kernel in disk cache for {func.attrs['global_symbol']}")
# Populate memory cache with disk result # Populate memory cache with disk result
self._memory_cache[key] = kernel self._memory_cache[key] = kernel
return kernel return kernel
@@ -262,11 +261,7 @@ class KernelCache:
            executable.export_library(temp_path)
            os.replace(temp_path, path)

    def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None, verbose: bool = False):
        """
        Persists a compiled kernel to disk cache.
@@ -292,8 +287,7 @@ class KernelCache:
            if verbose:
                self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
            if kernel.kernel_source is not None:
                KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source))
        except Exception as e:
            self.logger.error(f"Error saving kernel source code to disk: {e}")
@@ -303,13 +297,9 @@ class KernelCache:
            if verbose:
                self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
            if self.execution_backend == "tvm_ffi":
                KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_host_source()))
            else:
                KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_kernel_source()))
        except Exception as e:
            self.logger.error(f"Error saving host kernel source code to disk: {e}")
@@ -332,9 +322,7 @@ class KernelCache:
                src_lib_path = src_lib_path.replace(".cubin", ".py")
                if verbose:
                    self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
                KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
            elif self.execution_backend == "tvm_ffi":
                executable = kernel.adapter.executable
                if verbose:
@@ -344,9 +332,7 @@ class KernelCache:
                src_lib_path = kernel.adapter.libpath
                if verbose:
                    self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
                KernelCache._safe_write_file(kernel_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
        except Exception as e:
            self.logger.error(f"Error saving kernel library to disk: {e}")
@@ -356,8 +342,7 @@ class KernelCache:
            params_path = os.path.join(cache_path, PARAMS_PATH)
            if verbose:
                self.logger.debug(f"Saving kernel parameters to disk: {params_path}")
            KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file))
        except Exception as e:
            self.logger.error(f"Error saving kernel parameters to disk: {e}")
@@ -417,8 +402,7 @@ class KernelCache:
            self.logger.error(f"Error loading kernel source code from disk: {e}")
        try:
            if verbose:
                self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}")
            with open(host_kernel_path) as f:
                host_kernel_source = f.read()
        except Exception as e:
...
"""Base infra""" """Base infra"""
from .analysis import ( from .analysis import (
BlockInfo, # noqa: F401 BlockInfo, # noqa: F401
IterInfo, # noqa: F401 IterInfo, # noqa: F401
......
"""Analysis on TIR blocks, loops and functions.""" """Analysis on TIR blocks, loops and functions."""
from __future__ import annotations from __future__ import annotations
from typing_extensions import Literal from typing_extensions import Literal
...@@ -144,11 +145,13 @@ def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None: ...@@ -144,11 +145,13 @@ def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None:
var=iter.var, var=iter.var,
dom=iter.dom, dom=iter.dom,
loop_rv=loop, loop_rv=loop,
) for loop, iter in zip(loops, iters) )
for loop, iter in zip(loops, iters)
], ],
block_rv=block, block_rv=block,
reduction_block=is_reduction, reduction_block=is_reduction,
)) )
)
return blocks return blocks
...@@ -188,8 +191,7 @@ def get_max_shared_memory_per_block(target: Target) -> int: ...@@ -188,8 +191,7 @@ def get_max_shared_memory_per_block(target: Target) -> int:
_assert_gpu_target(target) _assert_gpu_target(target)
max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None) max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None)
if max_shared_memory_per_block is None: if max_shared_memory_per_block is None:
raise ValueError( raise ValueError(f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually")
f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually")
return int(max_shared_memory_per_block) return int(max_shared_memory_per_block)
...@@ -197,13 +199,11 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV: ...@@ -197,13 +199,11 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
try: try:
block = sch.mod[func_name].body.block block = sch.mod[func_name].body.block
except Exception: except Exception:
raise ValueError(f"The function body is expected to be the root block, but got:\n" raise ValueError(f"The function body is expected to be the root block, but got:\n{sch.mod[func_name].body}") from None
f"{sch.mod[func_name].body}") from None
return sch.get_block(block.name_hint) return sch.get_block(block.name_hint)
def collect_block_iter_vars_used_in_access_region(block: tir.Block, def collect_block_iter_vars_used_in_access_region(block: tir.Block, region: list[ir.Range]) -> set[tir.Var]:
region: list[ir.Range]) -> set[tir.Var]:
"""Collect the block iter variables used in the access region of a buffer region.""" """Collect the block iter variables used in the access region of a buffer region."""
tir_vars = set() tir_vars = set()
for expr in region: for expr in region:
...@@ -251,15 +251,13 @@ def is_broadcast_epilogue( ...@@ -251,15 +251,13 @@ def is_broadcast_epilogue(
for buffer_region in sch.get(epilogue).reads: for buffer_region in sch.get(epilogue).reads:
if buffer_region.buffer not in write_buffers: if buffer_region.buffer not in write_buffers:
continue continue
tir_vars = collect_block_iter_vars_used_in_access_region( tir_vars = collect_block_iter_vars_used_in_access_region(sch.get(epilogue), buffer_region.region)
sch.get(epilogue), buffer_region.region)
if len(tir_vars) < len(epilogue_iters): if len(tir_vars) < len(epilogue_iters):
return True return True
return False return False
def get_reduction_blocks(sch: tir.Schedule, def get_reduction_blocks(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]:
blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]:
# Get the main computation block # Get the main computation block
def is_reduction(block: BlockRV) -> bool: def is_reduction(block: BlockRV) -> bool:
block_stmt = sch.get(block) block_stmt = sch.get(block)
......
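The analysis helpers above are typically composed when preparing a schedule. A minimal usage sketch, assuming a TVM IRModule named mod with a "main" PrimFunc is already in scope (the variable name and the error handling are assumptions):

from tvm import tir

sch = tir.Schedule(mod)                 # `mod` is an assumed IRModule
root = get_root_block(sch)              # root block of the "main" PrimFunc
blocks = sch.get_child_blocks(root)     # leaf blocks under the root
reduction_blocks = get_reduction_blocks(sch, blocks)
if not reduction_blocks:
    raise ValueError("expected at least one reduction block to schedule")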
...@@ -39,18 +39,18 @@ def auto_infer_current_arch() -> TileDevice: ...@@ -39,18 +39,18 @@ def auto_infer_current_arch() -> TileDevice:
__all__ = [ __all__ = [
'is_cpu_arch', "is_cpu_arch",
'is_cuda_arch', "is_cuda_arch",
'is_volta_arch', "is_volta_arch",
'is_ampere_arch', "is_ampere_arch",
'is_ada_arch', "is_ada_arch",
'is_hopper_arch', "is_hopper_arch",
'is_tensorcore_supported_precision', "is_tensorcore_supported_precision",
'has_mma_support', "has_mma_support",
'is_cdna_arch', "is_cdna_arch",
'is_metal_arch', "is_metal_arch",
'CUDA', "CUDA",
'CDNA', "CDNA",
'METAL', "METAL",
'CPU', "CPU",
] ]
...@@ -7,9 +7,7 @@ class TileDevice: ...@@ -7,9 +7,7 @@ class TileDevice:
self.reg_cap: int = 0 # Register capacity: The amount of register memory available self.reg_cap: int = 0 # Register capacity: The amount of register memory available
self.smem_cap: int = 0 # Shared memory capacity: The amount of shared memory available self.smem_cap: int = 0 # Shared memory capacity: The amount of shared memory available
self.compute_max_core: int = 0 # The maximum number of computing cores self.compute_max_core: int = 0 # The maximum number of computing cores
self.warp_size: int = ( self.warp_size: int = 0 # The size of a warp, a group of threads that execute instructions in lockstep
0 # The size of a warp, a group of threads that execute instructions in lockstep
)
self.sm_partition: int = 0 # The number of streaming multiprocessor partitions self.sm_partition: int = 0 # The number of streaming multiprocessor partitions
self.transaction_size: list[int] = [ self.transaction_size: list[int] = [
0, 0,
...@@ -21,9 +19,7 @@ class TileDevice: ...@@ -21,9 +19,7 @@ class TileDevice:
0, 0,
] # Bandwidth specifications, possibly including peak and sustained rates ] # Bandwidth specifications, possibly including peak and sustained rates
self.platform: str = "unknown" # The platform or manufacturer of the device self.platform: str = "unknown" # The platform or manufacturer of the device
self.compute_capability: str = ( self.compute_capability: str = "unknown" # The compute capability, indicating the feature set and performance level
"unknown" # The compute capability, indicating the feature set and performance level
)
self.l2_cache_size_bytes: int = 0 self.l2_cache_size_bytes: int = 0
# the number of transaction size in bytes # the number of transaction size in bytes
self.transaction_size: list[int] = [0, 0] # in bytes self.transaction_size: list[int] = [0, 0] # in bytes
......
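Since TileDevice only declares placeholder defaults, the backend subclasses are expected to populate them from the target. A purely illustrative helper that reads the fields declared above, assuming arch is an already-constructed subclass instance (CUDA, CDNA, CPU, ...):

def describe_device(arch: TileDevice) -> str:
    # Only touches attributes declared in the TileDevice base class above.
    return (
        f"{arch.platform} (cc {arch.compute_capability}): "
        f"{arch.compute_max_core} cores, "
        f"{arch.smem_cap // 1024} KiB shared memory, "
        f"warp size {arch.warp_size}, "
        f"L2 {arch.l2_cache_size_bytes // (1024 * 1024)} MiB"
    )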
...@@ -9,7 +9,6 @@ def is_cdna_arch(arch: TileDevice) -> bool: ...@@ -9,7 +9,6 @@ def is_cdna_arch(arch: TileDevice) -> bool:
class CDNA(TileDevice): class CDNA(TileDevice):
def __init__(self, target: Target | str): def __init__(self, target: Target | str):
if isinstance(target, str): if isinstance(target, str):
target = tvm.target.Target(target) target = tvm.target.Target(target)
...@@ -33,6 +32,6 @@ class CDNA(TileDevice): ...@@ -33,6 +32,6 @@ class CDNA(TileDevice):
__all__ = [ __all__ = [
'is_cdna_arch', "is_cdna_arch",
'CDNA', "CDNA",
] ]
...@@ -10,7 +10,6 @@ def is_cpu_arch(arch: TileDevice) -> bool: ...@@ -10,7 +10,6 @@ def is_cpu_arch(arch: TileDevice) -> bool:
# For the LLVM backend, we do not provide detailed CPU information, # For the LLVM backend, we do not provide detailed CPU information,
# as the LLVM backend does not require tuning; we just maintain consistency. # as the LLVM backend does not require tuning; we just maintain consistency.
class CPU(TileDevice): class CPU(TileDevice):
def __init__(self, target: Target): def __init__(self, target: Target):
self.target = target self.target = target
device = tvm.runtime.cpu(0) device = tvm.runtime.cpu(0)
...@@ -21,6 +20,6 @@ class CPU(TileDevice): ...@@ -21,6 +20,6 @@ class CPU(TileDevice):
__all__ = [ __all__ = [
'is_cpu_arch', "is_cpu_arch",
'CPU', "CPU",
] ]
...@@ -78,7 +78,6 @@ hopper_tensorcore_supported = ada_tensorcore_supported ...@@ -78,7 +78,6 @@ hopper_tensorcore_supported = ada_tensorcore_supported
# instead of assuming both a and b share the same dtype. # instead of assuming both a and b share the same dtype.
# As the tensorcore may support float8_e4m3 * float8_e5m2 # As the tensorcore may support float8_e4m3 * float8_e5m2
def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool: def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool:
if is_volta_arch(arch): if is_volta_arch(arch):
return (in_dtype, accum_dtype) in volta_tensorcore_supported return (in_dtype, accum_dtype) in volta_tensorcore_supported
elif is_ampere_arch(arch): elif is_ampere_arch(arch):
...@@ -92,7 +91,6 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til ...@@ -92,7 +91,6 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til
class TensorInstruction: class TensorInstruction:
def __init__( def __init__(
self, self,
name: str, name: str,
...@@ -104,7 +102,6 @@ class TensorInstruction: ...@@ -104,7 +102,6 @@ class TensorInstruction:
class CUDA(TileDevice): class CUDA(TileDevice):
def __init__(self, target: Target | str): def __init__(self, target: Target | str):
if isinstance(target, str): if isinstance(target, str):
target = tvm.target.Target(target) target = tvm.target.Target(target)
...@@ -148,12 +145,12 @@ class CUDA(TileDevice): ...@@ -148,12 +145,12 @@ class CUDA(TileDevice):
__all__ = [ __all__ = [
'is_cuda_arch', "is_cuda_arch",
'is_volta_arch', "is_volta_arch",
'is_ampere_arch', "is_ampere_arch",
'is_ada_arch', "is_ada_arch",
'is_hopper_arch', "is_hopper_arch",
'is_tensorcore_supported_precision', "is_tensorcore_supported_precision",
'has_mma_support', "has_mma_support",
"CUDA", "CUDA",
] ]
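A hedged sketch of how the precision check above might gate a schedule choice. The target string, the dtype pair, and the has_mma_support(arch) call signature are assumptions taken only from the annotations and export list in this diff:

arch = CUDA("cuda")  # illustrative target string; a concrete tvm.target.Target also works
if is_tensorcore_supported_precision("float16", "float32", arch):
    # fp16 inputs with fp32 accumulation are tensor-core eligible on this arch
    use_mma = has_mma_support(arch)  # signature assumed from the export list
else:
    # fall back to a plain SIMT schedule
    use_mma = False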
...@@ -83,8 +83,7 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") ...@@ -83,8 +83,7 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes. Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
""" """
assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb" assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
shared_mem = get_device_attribute( shared_mem = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id)
cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id)
if format == "bytes": if format == "bytes":
return shared_mem return shared_mem
elif format == "kb": elif format == "kb":
......
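A small usage sketch of the attribute query above; device 0 and the kilobyte conversion factor are assumptions:

smem_bytes = get_max_dynamic_shared_size_bytes(device_id=0)             # raw bytes
smem_kb = get_max_dynamic_shared_size_bytes(device_id=0, format="kb")   # presumably bytes // 1024
print(f"max shared memory per multiprocessor: {smem_bytes} bytes (~{smem_kb} KB)")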