"vscode:/vscode.git/clone" did not exist on "c9c70778e6050f6a105902444af57824936a2c72"
Unverified commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
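The hunks below are almost entirely mechanical: the same statements reformatted from yapf's hanging-indent style to ruff format's black-like layout (string quotes normalized to double quotes, and long calls split with one argument per line plus a trailing comma instead of manually packed continuation lines). As a rough illustration of the pattern, here is a minimal sketch with a made-up helper; the function and argument names are hypothetical and are not taken from this commit.

# Illustration only: a toy helper to show the before/after layout.
# The names below are hypothetical and do not appear in this commit.
def run_benchmark(kernel, warmup, rep, supply_type, cache_inputs):
    return {"kernel": kernel, "warmup": warmup, "rep": rep,
            "supply_type": supply_type, "cache_inputs": cache_inputs}

# yapf (before): continuation lines aligned to the opening parenthesis,
# arguments packed onto as few lines as fit, single quotes left untouched.
result_before = run_benchmark('gemm', 25, 100,
                              supply_type='auto',
                              cache_inputs=False)

# ruff format (after): quotes normalized to double quotes; a call that does
# not fit on one line is split with one argument per line and a trailing comma.
result_after = run_benchmark(
    "gemm",
    25,
    100,
    supply_type="auto",
    cache_inputs=False,
)

assert result_before == result_after  # only the formatting differs

Applied mechanically across the repository, this kind of transformation accounts for essentially all of the hunks that follow.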
......@@ -32,7 +32,6 @@ block_K = 32
def test_warp_specialized():
@T.prim_func
def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
bx = T.launch_thread("blockIdx.x", 8)
......@@ -47,25 +46,27 @@ def test_warp_specialized():
for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
if v == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
2, 0), 0,
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, by * 64)
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32,
by * 64,
)
if v == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
2, 0), 0,
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64, k * 32)
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64,
k * 32,
)
T.call_extern(
"handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
"handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
)
@T.prim_func
def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
......@@ -85,34 +86,35 @@ def test_warp_specialized():
T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
if v - 128 == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
2, 0), T.get_mbarrier(k % 3),
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, by * 64)
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
T.get_mbarrier(k % 3),
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32,
by * 64,
)
if v - 128 == 0:
T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
if v - 128 == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
2, 0), T.get_mbarrier(k % 3),
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64, k * 32)
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
T.get_mbarrier(k % 3),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64,
k * 32,
)
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
else:
T.set_max_nreg(240, 1)
for k in range(16):
T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
T.call_extern(
"handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
T.evaluate(
tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
"handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
)
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
_check(before, after)
......
......@@ -6,7 +6,7 @@ from tilelang.utils.sparse import compress_sm90, randn_semi_sparse
def _test_compress_sm90(M, K, block_k, dtype):
A = randn_semi_sparse(M, K, dtype=dtype, device='cuda')
A = randn_semi_sparse(M, K, dtype=dtype, device="cuda")
A_sparse, E = compress_sm90(A, block_k, False)
......
......@@ -5,12 +5,11 @@ import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......
......@@ -23,6 +23,7 @@ def _compute_version() -> str:
if version_file.is_file():
try:
from version_provider import dynamic_metadata # type: ignore
return dynamic_metadata("version")
except Exception:
# Fall back to the raw VERSION file if provider isn't available.
......@@ -33,6 +34,7 @@ def _compute_version() -> str:
try:
from importlib.metadata import version as _dist_version # py3.8+
return _dist_version("tilelang")
except Exception as exc:
warnings.warn(
......
from __future__ import annotations
from tvm import tir
from tvm.tir import (PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm)
from tvm.tir import PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm
from tvm.tir.transform import prim_func_pass
from tvm.tir.stmt_functor import post_order_visit
......@@ -22,14 +22,14 @@ class _LoopVarUseAnalyzer(PyStmtExprVisitor):
def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]:
"""
Collect local buffer accesses in the loop body.
Collect local buffer accesses in the loop body.
Args:
statement: The TIR statement to analyze
Args:
statement: The TIR statement to analyze
Returns:
Tuple of buffer accesses in the loop body.
"""
Returns:
Tuple of buffer accesses in the loop body.
"""
buffer_accesses = []
......@@ -44,7 +44,6 @@ def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]:
@tir.functor.visitor
class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
def __init__(self) -> None:
super().__init__()
......@@ -75,7 +74,8 @@ class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
raise ValueError(
"[Tilelang Semantic Check] "
f"Loop variable {loop.loop_var} in a T.Parallel loop with symbolic range (min={loop.min}, extent={loop.extent}) is used to index "
"a local/fragment buffer, which is not allowed in Tilelang.")
"a local/fragment buffer, which is not allowed in Tilelang."
)
return
......
......@@ -23,10 +23,7 @@ def print_fragment_format(layout: T.Fragment) -> str:
if isinstance(layout, T.Fragment):
input_shape = layout.get_input_shape()
output_shape = layout.get_output_shape()
lines = [
f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}",
f" Index: {layout.forward_index}"
]
lines = [f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", f" Index: {layout.forward_index}"]
print("\n".join(lines))
else:
raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}")
......@@ -82,7 +79,6 @@ class _LayoutVisualVisitor(PyStmtExprVisitor):
def LayoutVisual(formats: str = ""):
def pass_fn(func: tir.PrimFunc, mod, ctx):
_LayoutVisualVisitor(formats=formats).visit_stmt(func.body)
return func
......
......@@ -11,10 +11,7 @@ from tvm.tir.transform import prim_func_pass
def is_pipelined_for(op: For) -> bool:
"""Check if a for loop is pipelined."""
anno_keys = [
"num_stages", "tl_pipeline_order", "tl_pipeline_stage", "tl_pipeline_sync",
"tl_pipeline_group"
]
anno_keys = ["num_stages", "tl_pipeline_order", "tl_pipeline_stage", "tl_pipeline_sync", "tl_pipeline_group"]
return any(key in op.annotations for key in anno_keys)
......@@ -26,7 +23,6 @@ def is_tile_op(op: Call) -> bool:
@tir.functor.visitor
class _NestedLoopCheckVisitor(PyStmtExprVisitor):
def __init__(self) -> None:
super().__init__()
self.in_parallel_context = False
......@@ -42,27 +38,24 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
# Otherwise
if self.in_parallel_context:
raise ValueError("[Tilelang Semantic Check] "
"Nested parallel loops are not allowed. "
"Please check your loop structure.")
raise ValueError("[Tilelang Semantic Check] Nested parallel loops are not allowed. Please check your loop structure.")
self.in_parallel_context = True
super().visit_for_(op)
self.in_parallel_context = False
return
elif is_pipelined_for(op):
if self.in_parallel_context:
raise ValueError("[Tilelang Semantic Check] "
"Pipelined loop cannot be nested inside a parallel loop. "
"Please check your loop structure.")
raise ValueError(
"[Tilelang Semantic Check] Pipelined loop cannot be nested inside a parallel loop. Please check your loop structure."
)
super().visit_for_(op)
def visit_call_(self, op: Call) -> None:
if self.in_parallel_context and is_tile_op(op):
raise ValueError("[Tilelang Semantic Check] "
"Only elementwise operations are allowed inside a parallel loop. " \
f"Got a tile-op \"{op.op}\"."
)
raise ValueError(
f'[Tilelang Semantic Check] Only elementwise operations are allowed inside a parallel loop. Got a tile-op "{op.op}".'
)
def NestedLoopChecker():
......
......@@ -85,8 +85,7 @@ def _get_current_stack() -> CaptureStack:
class AutotuneInputsCapture:
__slots__ = ("tensors")
__slots__ = "tensors"
def __init__(self, tensors: list[Any]):
self.tensors = tensors
......
"""The auto-tune parameters.
"""
"""The auto-tune parameters."""
from __future__ import annotations
import tilelang
......@@ -50,7 +50,7 @@ class CompileArgs:
out_idx: list[int] | int | None = None
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto"
target: Literal['auto', 'cuda', 'hip'] = 'auto'
target: Literal["auto", "cuda", "hip"] = "auto"
target_host: str | Target = None
verbose: bool = False
pass_configs: dict[str, Any] | None = None
......@@ -62,24 +62,20 @@ class CompileArgs:
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
pass_configs=self.pass_configs)
pass_configs=self.pass_configs,
)
def __hash__(self):
data = {
"execution_backend":
self.execution_backend,
"target":
str(self.target),
"target_host":
str(self.target_host) if self.target_host else None,
"verbose":
self.verbose,
"pass_configs":
json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None,
"execution_backend": self.execution_backend,
"target": str(self.target),
"target_host": str(self.target_host) if self.target_host else None,
"verbose": self.verbose,
"pass_configs": json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None,
}
hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8'))
return int.from_bytes(hash_obj.digest(), byteorder='big')
hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8"))
return int.from_bytes(hash_obj.digest(), byteorder="big")
@dataclass(frozen=True)
......@@ -104,6 +100,7 @@ class ProfileArgs:
manual_check_prog: Callable = None
cache_input_tensors: bool = True
"""
warmup: int = 25
rep: int = 100
timeout: int = 30
......@@ -127,8 +124,8 @@ class ProfileArgs:
"atol": self.atol,
"max_mismatched_ratio": self.max_mismatched_ratio,
}
hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8'))
return int.from_bytes(hash_obj.digest(), byteorder='big')
hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8"))
return int.from_bytes(hash_obj.digest(), byteorder="big")
@dataclass(frozen=True)
......@@ -143,6 +140,7 @@ class AutotuneResult:
func: Optimized function.
kernel: Compiled kernel function.
"""
latency: float | None = None
config: dict | None = None
ref_latency: float | None = None
......@@ -199,8 +197,7 @@ class AutotuneResult:
if verbose:
logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
if kernel.kernel_source is not None:
self._safe_write_file(device_kernel_path, "w",
lambda f: f.write(kernel.kernel_source))
self._safe_write_file(device_kernel_path, "w", lambda f: f.write(kernel.kernel_source))
except Exception as e:
logger.error(f"Error saving kernel source code to disk: {e}")
......@@ -211,11 +208,9 @@ class AutotuneResult:
logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
# Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel
if kernel.execution_backend == "tvm_ffi":
self._safe_write_file(host_kernel_path, "w",
lambda f: f.write(kernel.adapter.get_host_source()))
self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_host_source()))
else:
self._safe_write_file(host_kernel_path, "w",
lambda f: f.write(kernel.adapter.get_kernel_source()))
self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_kernel_source()))
except Exception as e:
logger.error(f"Error saving wrapped kernel source code to disk: {e}")
......@@ -237,12 +232,10 @@ class AutotuneResult:
py_src_path = src_lib_path.replace(".cubin", ".py")
if verbose:
logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
self._safe_write_file(kernel_py_path, "wb",
lambda f: f.write(self._load_binary(py_src_path)))
self._safe_write_file(kernel_py_path, "wb", lambda f: f.write(self._load_binary(py_src_path)))
if verbose:
logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
self._safe_write_file(kernel_lib_path, "wb",
lambda f: f.write(self._load_binary(src_lib_path)))
self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path)))
elif kernel.execution_backend == "tvm_ffi":
executable = kernel.adapter.executable
if verbose:
......@@ -252,8 +245,7 @@ class AutotuneResult:
src_lib_path = kernel.adapter.libpath
if verbose:
logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
self._safe_write_file(kernel_lib_path, "wb",
lambda f: f.write(self._load_binary(src_lib_path)))
self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path)))
except Exception as e:
logger.error(f"Error saving kernel library to disk: {e}")
......@@ -370,14 +362,12 @@ class AutotuneResult:
# save best config (atomic)
if verbose:
logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}")
self._safe_write_file(
str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f))
self._safe_write_file(str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f))
# save function (atomic)
if verbose:
logger.debug(f"Saving function to file: {path / FUNCTION_PATH}")
self._safe_write_file(
str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f))
self._safe_write_file(str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f))
# save ref latency (atomic)
if verbose:
......@@ -385,10 +375,13 @@ class AutotuneResult:
self._safe_write_file(
str(path / LATENCY_PATH),
"w",
lambda f: json.dump({
"latency": self.latency,
"ref_latency": self.ref_latency,
}, f),
lambda f: json.dump(
{
"latency": self.latency,
"ref_latency": self.ref_latency,
},
f,
),
)
# save kernel
......@@ -403,8 +396,8 @@ class AutotuneResult:
# Normalize target and resolve execution backend for loading
from tilelang.utils.target import determine_target as _determine_target
from tilelang.jit.execution_backend import resolve_execution_backend
norm_target = Target(_determine_target(compile_args.target)) if isinstance(
compile_args.target, str) else compile_args.target
norm_target = Target(_determine_target(compile_args.target)) if isinstance(compile_args.target, str) else compile_args.target
requested_backend = compile_args.execution_backend
resolved_backend = resolve_execution_backend(requested_backend, norm_target)
# load best config
......
......@@ -3,6 +3,7 @@
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
"""
from __future__ import annotations
from dataclasses import dataclass
......@@ -14,7 +15,8 @@ from tvm.tir import PrimFunc, Var
from tvm.target import Target
import inspect
from functools import partial
from typing import (Callable, Generic, Literal, Any, TypeVar)
from typing import Callable, Generic, Literal, Any, TypeVar
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
......@@ -74,8 +76,8 @@ def _init_logger_handlers():
global _logger_handlers_initialized
if _logger_handlers_initialized:
return
formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
file_handler = logging.FileHandler('autotuner.log', mode='w')
formatter = logging.Formatter("%(asctime)s %(levelname)s:%(message)s")
file_handler = logging.FileHandler("autotuner.log", mode="w")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
console_handler = logging.StreamHandler(sys.stdout)
......@@ -87,8 +89,7 @@ def _init_logger_handlers():
def get_available_cpu_count() -> int:
"""Gets the number of CPU cores available to the current process.
"""
"""Gets the number of CPU cores available to the current process."""
try:
cpu_count = len(os.sched_getaffinity(0))
except AttributeError:
......@@ -107,6 +108,7 @@ class AutoTuner:
fn: The function to be auto-tuned.
configs: List of configurations to try during auto-tuning.
"""
compile_args = CompileArgs()
profile_args = ProfileArgs()
......@@ -137,14 +139,15 @@ class AutoTuner:
"""
return cls(kernel, configs)
def set_compile_args(self,
out_idx: list[int] | int | None = None,
target: Literal['auto', 'cuda', 'hip', 'metal'] = 'auto',
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc",
"torch"] = "auto",
target_host: str | Target = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None):
def set_compile_args(
self,
out_idx: list[int] | int | None = None,
target: Literal["auto", "cuda", "hip", "metal"] = "auto",
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
target_host: str | Target = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
):
"""Set compilation arguments for the auto-tuner.
Args:
......@@ -161,6 +164,7 @@ class AutoTuner:
# Normalize target to a concrete TVM Target and resolve execution backend
t = Target(determine_target(target))
from tilelang.jit.execution_backend import resolve_execution_backend
resolved_backend = resolve_execution_backend(execution_backend, t)
self.compile_args = CompileArgs(
......@@ -169,23 +173,26 @@ class AutoTuner:
execution_backend=resolved_backend,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs)
pass_configs=pass_configs,
)
return self
def set_profile_args(self,
warmup: int = 25,
rep: int = 100,
timeout: int = 30,
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False):
def set_profile_args(
self,
warmup: int = 25,
rep: int = 100,
timeout: int = 30,
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False,
):
"""Set profiling arguments for the auto-tuner.
Args:
......@@ -209,9 +216,7 @@ class AutoTuner:
# the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead.
if get_autotune_inputs() is not None:
if supply_prog is not None:
logger.warning(
"`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context."
)
logger.warning("`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context.")
supply_prog = lambda _: get_autotune_inputs() # noqa: E731
self.profile_args = ProfileArgs(
......@@ -226,13 +231,13 @@ class AutoTuner:
cache_input_tensors=cache_input_tensors,
warmup=warmup,
rep=rep,
timeout=timeout)
timeout=timeout,
)
# If a custom `supply_prog` is provided, the profiler's `supply_type` setting
# becomes ineffective. The custom supply program will be used instead.
if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
logger.warning("Ignoring `supply_type` passed to `set_profile_args` because "
"`supply_prog` is not None.")
logger.warning("Ignoring `supply_type` passed to `set_profile_args` because `supply_prog` is not None.")
return self
......@@ -241,10 +246,8 @@ class AutoTuner:
self._kernel_parameters = k_parameters
self._function_parameters = f_parameters
def generate_cache_key(self, parameters: dict[str, Any],
extra_parameters: dict[str, Any]) -> AutotuneResult | None:
"""Generate a cache key for the auto-tuning process.
"""
def generate_cache_key(self, parameters: dict[str, Any], extra_parameters: dict[str, Any]) -> AutotuneResult | None:
"""Generate a cache key for the auto-tuning process."""
def _normalize_param(value):
if isinstance(value, Var):
......@@ -315,8 +318,9 @@ class AutoTuner:
if var_name in parameters:
continue
# Cell content must be serializable
assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), \
assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), (
f"Cell contents {cell.cell_contents} is not serializable: {type(cell.cell_contents)}"
)
extra_parameters[var_name] = cell.cell_contents
if isinstance(self.configs, Callable):
......@@ -328,8 +332,10 @@ class AutoTuner:
if env.is_cache_enabled() and not env.is_autotune_cache_disabled():
# First check in-memory cache
if key in self._memory_cache:
logger.warning("Found kernel in memory cache. For better performance," \
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.")
logger.warning(
"Found kernel in memory cache. For better performance,"
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel."
)
return self._memory_cache[key]
# Then check disk cache
......@@ -369,7 +375,6 @@ class AutoTuner:
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
# or the default profiler input generation (`profiler._get_inputs`).
def get_input_tensors_supply(with_output: bool):
def func():
if supply_prog is not None:
return supply_prog(profiler._get_params(with_output=with_output))
......@@ -387,8 +392,7 @@ class AutoTuner:
self.jit_input_tensors = jit_input_tensors_supply()
else:
# check if the cached tensors are compatible with the current configuration
assert len(params) == len(
self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)"
assert len(params) == len(self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)"
for p, c in zip(params, self.jit_input_tensors):
if not isinstance(c, torch.Tensor):
# skip non-tensor inputs checking
......@@ -397,8 +401,8 @@ class AutoTuner:
# Check tensor compatibility using generator expression
def shape_equal(a, b):
return all(
a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var)
for a_dim, b_dim in zip(a.shape, b.shape))
a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) for a_dim, b_dim in zip(a.shape, b.shape)
)
if p.dtype != c.dtype or not shape_equal(p, c):
logger.warning(
......@@ -409,7 +413,8 @@ class AutoTuner:
"To ensure fresh, compatible inputs are generated for every trial "
"you can disable caching by setting:\n"
" `cache_input_tensors=False`\n"
"within your `.set_compile_args(...)` call.\n")
"within your `.set_compile_args(...)` call.\n"
)
# otherwise, regenerate the input tensors for safety
self.jit_input_tensors = jit_input_tensors_supply()
break
......@@ -418,24 +423,16 @@ class AutoTuner:
if (not skip_check) and (ref_prog is not None):
if manual_check_prog is not None:
profiler.manual_assert_close(
ref_prog,
input_tensors=self.jit_input_tensors,
manual_check_prog=manual_check_prog)
profiler.manual_assert_close(ref_prog, input_tensors=self.jit_input_tensors, manual_check_prog=manual_check_prog)
else:
profiler.assert_allclose(
ref_prog,
input_tensors=self.jit_input_tensors,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio)
latency = profiler.do_bench(
warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
ref_prog, input_tensors=self.jit_input_tensors, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio
)
latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
if self.ref_latency_cache is None and ref_prog is not None:
self.ref_input_tensors = ref_input_tensors_supply()
self.ref_latency_cache = profiler.do_bench(
ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
self.ref_latency_cache = profiler.do_bench(ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
return latency, self.ref_latency_cache
......@@ -469,17 +466,14 @@ class AutoTuner:
# Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple
if any(key in top_config for key, _ in key_kwargs_tuple) or any(
check_tunable_argument_value(key, self._function_parameters, key_args_tuple)
for key in tunable_arguments):
check_tunable_argument_value(key, self._function_parameters, key_args_tuple) for key in tunable_arguments
):
logger.warning(
f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT"
)
# compile the kernel with the provided parameters
jit_kernel = self.jit_compile()
autotuner_result = AutotuneResult(
libcode=jit_kernel.get_kernel_source(),
func=jit_kernel.prim_func,
kernel=jit_kernel)
autotuner_result = AutotuneResult(libcode=jit_kernel.get_kernel_source(), func=jit_kernel.prim_func, kernel=jit_kernel)
self._memory_cache[key] = autotuner_result
return autotuner_result
# get the cpu count
......@@ -489,9 +483,7 @@ class AutoTuner:
max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
if cpu_counts > 0:
num_workers = min(cpu_counts, available_cpu_count)
logger.info(
f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used"
)
logger.info(f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used")
else:
num_workers = max(1, int(available_cpu_count * cpu_utilizations))
logger.info(
......@@ -509,7 +501,6 @@ class AutoTuner:
future_to_index = {}
def cuda_device_wrapper(func, device):
def inner(**config_arg):
torch.cuda.set_device(device)
return func(**config_arg)
......@@ -532,18 +523,14 @@ class AutoTuner:
future_to_index[future] = i
results_with_configs = []
for future in tqdm(
concurrent.futures.as_completed(futures),
total=len(futures),
desc="Compiling configurations"):
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Compiling configurations"):
idx = future_to_index[future]
config = config_args[idx]
try:
result = future.result()
results_with_configs.append((result, config))
except Exception as e:
logger.debug(
f"Compilation failed for config {config} at index {idx} with error: {e}")
logger.debug(f"Compilation failed for config {config} at index {idx} with error: {e}")
continue
ref_latency = None
......@@ -556,14 +543,10 @@ class AutoTuner:
# latency, ref_latency = target_fn(jit_kernel)
latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
except TimeoutException:
logger.warning(
f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
)
logger.warning(f"A timeout occurred while testing config {config}, checkout autotuner.log for more details")
continue
except Exception:
logger.warning(
f"An error occurred while testing config {config}, checkout autotuner.log for more details"
)
logger.warning(f"An error occurred while testing config {config}, checkout autotuner.log for more details")
logger.debug(f"Error: {traceback.format_exc()}")
continue
......@@ -578,8 +561,7 @@ class AutoTuner:
pool.shutdown()
if best_kernel is None:
error_msg = ("Auto-tuning failed: No configuration successfully "
"compiled and passed benchmarking/validation.")
error_msg = "Auto-tuning failed: No configuration successfully compiled and passed benchmarking/validation."
logger.error(error_msg)
raise RuntimeError(error_msg)
......@@ -595,7 +577,8 @@ class AutoTuner:
ref_latency=ref_latency,
libcode=best_kernel.get_kernel_source(),
func=best_kernel.prim_func,
kernel=best_kernel)
kernel=best_kernel,
)
if self.compile_args.execution_backend in ("torch"):
logger.warning("DLPack backend does not support cache saving to disk.")
......@@ -617,8 +600,8 @@ class AutoTuner:
return self.run()
_P = ParamSpec('_P')
_T = TypeVar('_T')
_P = ParamSpec("_P")
_T = TypeVar("_T")
@dataclass
......@@ -643,8 +626,9 @@ class AutoTuneImpl(Generic[_P, _T]):
self._tuner_cache = {}
def get_tunner(self):
autotuner = AutoTuner(
self.jit_impl.func, configs=self.configs).set_profile_args(
autotuner = (
AutoTuner(self.jit_impl.func, configs=self.configs)
.set_profile_args(
supply_type=self.supply_type,
ref_prog=self.ref_prog,
supply_prog=self.supply_prog,
......@@ -654,7 +638,8 @@ class AutoTuneImpl(Generic[_P, _T]):
skip_check=self.skip_check,
manual_check_prog=self.manual_check_prog,
cache_input_tensors=self.cache_input_tensors,
).set_compile_args(
)
.set_compile_args(
out_idx=self.jit_impl.out_idx,
execution_backend=self.jit_impl.execution_backend,
target=self.jit_impl.target,
......@@ -662,6 +647,7 @@ class AutoTuneImpl(Generic[_P, _T]):
verbose=self.jit_impl.verbose,
pass_configs=self.jit_impl.pass_configs,
)
)
autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout)
return autotuner
......@@ -753,16 +739,13 @@ def autotune( # This is the new public interface
if callable(func):
# Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults)
# This is a placeholder for a real auto tuner implementation
raise ValueError(
"Use tilelang.autotune to decorate func without arguments is not supported yet.")
raise ValueError("Use tilelang.autotune to decorate func without arguments is not supported yet.")
elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
else:
def decorator(impl):
assert isinstance(
impl, JITImpl
), "The @autotune decorator can only be applied to @tilelang.jit decorated instances."
assert isinstance(impl, JITImpl), "The @autotune decorator can only be applied to @tilelang.jit decorated instances."
return AutoTuneImpl(
jit_impl=impl,
configs=configs,
......
"""The cache utils with class and database persistence - Init file"""
from __future__ import annotations
from typing import Literal
......@@ -18,8 +19,7 @@ def cached(
*args,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
| None = "auto",
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] | None = "auto",
verbose: bool | None = False,
pass_configs: dict | None = None,
compile_flags: list[str] | str | None = None,
......@@ -36,7 +36,8 @@ def cached(
execution_backend=execution_backend,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags)
compile_flags=compile_flags,
)
def clear_cache():
......@@ -47,9 +48,11 @@ def clear_cache():
RuntimeError: Always raised to warn users to clear the cache manually.
"""
cache_dir = env.TILELANG_CACHE_DIR
raise RuntimeError("tilelang.clear_cache() is disabled because deleting the cache directory "
"is dangerous. If you accept the risk, remove it manually with "
f"`rm -rf '{cache_dir}'`.")
raise RuntimeError(
"tilelang.clear_cache() is disabled because deleting the cache directory "
"is dangerous. If you accept the risk, remove it manually with "
f"`rm -rf '{cache_dir}'`."
)
if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
......
"""The cache utils with class and database persistence - KernelCache Class"""
from __future__ import annotations
import json
......@@ -97,9 +98,7 @@ class KernelCache:
"version": __version__,
"func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key
"out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]),
"args_repr": tuple(
repr(arg) for arg in args
), # Use repr to serialize arguments, may need more robust serialization
"args_repr": tuple(repr(arg) for arg in args), # Use repr to serialize arguments, may need more robust serialization
"target": str(target),
"target_host": str(target_host) if target_host else None,
"execution_backend": execution_backend,
......@@ -118,8 +117,7 @@ class KernelCache:
*args,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc",
"torch"] = "auto",
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
verbose: bool = False,
pass_configs: dict = None,
compile_flags: list[str] | str | None = None,
......@@ -140,6 +138,7 @@ class KernelCache:
# Normalize target and resolve execution backend before proceeding
from tilelang.utils.target import determine_target as _determine_target
from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target
norm_target = Target(_determine_target(target)) if isinstance(target, str) else target
requested_backend = execution_backend
execution_backend = resolve_execution_backend(requested_backend, norm_target)
......@@ -180,21 +179,21 @@ class KernelCache:
with self._lock:
# First check in-memory cache
if key in self._memory_cache:
self.logger.warning("Found kernel in memory cache. For better performance," \
" consider using `@tilelang.jit` instead of direct kernel caching.")
self.logger.warning(
"Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching."
)
return self._memory_cache[key]
if verbose:
self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}")
# Then check disk cache
kernel = self._load_kernel_from_disk(key, norm_target, target_host, out_idx,
execution_backend, pass_configs, compile_flags,
func, verbose)
kernel = self._load_kernel_from_disk(
key, norm_target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose
)
if kernel is not None:
if verbose:
self.logger.debug(
f"Found kernel in disk cache for {func.attrs['global_symbol']}")
self.logger.debug(f"Found kernel in disk cache for {func.attrs['global_symbol']}")
# Populate memory cache with disk result
self._memory_cache[key] = kernel
return kernel
......@@ -262,11 +261,7 @@ class KernelCache:
executable.export_library(temp_path)
os.replace(temp_path, path)
def _save_kernel_to_disk(self,
key: str,
kernel: JITKernel,
func: Callable = None,
verbose: bool = False):
def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None, verbose: bool = False):
"""
Persists a compiled kernel to disk cache.
......@@ -292,8 +287,7 @@ class KernelCache:
if verbose:
self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
if kernel.kernel_source is not None:
KernelCache._safe_write_file(device_kernel_path, "w",
lambda file: file.write(kernel.kernel_source))
KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source))
except Exception as e:
self.logger.error(f"Error saving kernel source code to disk: {e}")
......@@ -303,13 +297,9 @@ class KernelCache:
if verbose:
self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
if self.execution_backend == "tvm_ffi":
KernelCache._safe_write_file(
host_kernel_path, "w",
lambda file: file.write(kernel.adapter.get_host_source()))
KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_host_source()))
else:
KernelCache._safe_write_file(
host_kernel_path, "w",
lambda file: file.write(kernel.adapter.get_kernel_source()))
KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_kernel_source()))
except Exception as e:
self.logger.error(f"Error saving host kernel source code to disk: {e}")
......@@ -332,9 +322,7 @@ class KernelCache:
src_lib_path = src_lib_path.replace(".cubin", ".py")
if verbose:
self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
KernelCache._safe_write_file(
kernel_py_path, "wb",
lambda file: file.write(KernelCache._load_binary(src_lib_path)))
KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
elif self.execution_backend == "tvm_ffi":
executable = kernel.adapter.executable
if verbose:
......@@ -344,9 +332,7 @@ class KernelCache:
src_lib_path = kernel.adapter.libpath
if verbose:
self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
KernelCache._safe_write_file(
kernel_lib_path, "wb",
lambda file: file.write(KernelCache._load_binary(src_lib_path)))
KernelCache._safe_write_file(kernel_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
except Exception as e:
self.logger.error(f"Error saving kernel library to disk: {e}")
......@@ -356,8 +342,7 @@ class KernelCache:
params_path = os.path.join(cache_path, PARAMS_PATH)
if verbose:
self.logger.debug(f"Saving kernel parameters to disk: {params_path}")
KernelCache._safe_write_file(params_path, "wb",
lambda file: cloudpickle.dump(kernel.params, file))
KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file))
except Exception as e:
self.logger.error(f"Error saving kernel parameters to disk: {e}")
......@@ -417,8 +402,7 @@ class KernelCache:
self.logger.error(f"Error loading kernel source code from disk: {e}")
try:
if verbose:
self.logger.debug(
f"Loading wrapped kernel source code from file: {host_kernel_path}")
self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}")
with open(host_kernel_path) as f:
host_kernel_source = f.read()
except Exception as e:
......
"""Base infra"""
from .analysis import (
BlockInfo, # noqa: F401
IterInfo, # noqa: F401
......
"""Analysis on TIR blocks, loops and functions."""
from __future__ import annotations
from typing_extensions import Literal
......@@ -144,11 +145,13 @@ def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None:
var=iter.var,
dom=iter.dom,
loop_rv=loop,
) for loop, iter in zip(loops, iters)
)
for loop, iter in zip(loops, iters)
],
block_rv=block,
reduction_block=is_reduction,
))
)
)
return blocks
......@@ -188,8 +191,7 @@ def get_max_shared_memory_per_block(target: Target) -> int:
_assert_gpu_target(target)
max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None)
if max_shared_memory_per_block is None:
raise ValueError(
f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually")
raise ValueError(f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually")
return int(max_shared_memory_per_block)
......@@ -197,13 +199,11 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
try:
block = sch.mod[func_name].body.block
except Exception:
raise ValueError(f"The function body is expected to be the root block, but got:\n"
f"{sch.mod[func_name].body}") from None
raise ValueError(f"The function body is expected to be the root block, but got:\n{sch.mod[func_name].body}") from None
return sch.get_block(block.name_hint)
def collect_block_iter_vars_used_in_access_region(block: tir.Block,
region: list[ir.Range]) -> set[tir.Var]:
def collect_block_iter_vars_used_in_access_region(block: tir.Block, region: list[ir.Range]) -> set[tir.Var]:
"""Collect the block iter variables used in the access region of a buffer region."""
tir_vars = set()
for expr in region:
......@@ -251,15 +251,13 @@ def is_broadcast_epilogue(
for buffer_region in sch.get(epilogue).reads:
if buffer_region.buffer not in write_buffers:
continue
tir_vars = collect_block_iter_vars_used_in_access_region(
sch.get(epilogue), buffer_region.region)
tir_vars = collect_block_iter_vars_used_in_access_region(sch.get(epilogue), buffer_region.region)
if len(tir_vars) < len(epilogue_iters):
return True
return False
def get_reduction_blocks(sch: tir.Schedule,
blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]:
def get_reduction_blocks(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]:
# Get the main computation block
def is_reduction(block: BlockRV) -> bool:
block_stmt = sch.get(block)
......
......@@ -39,18 +39,18 @@ def auto_infer_current_arch() -> TileDevice:
__all__ = [
'is_cpu_arch',
'is_cuda_arch',
'is_volta_arch',
'is_ampere_arch',
'is_ada_arch',
'is_hopper_arch',
'is_tensorcore_supported_precision',
'has_mma_support',
'is_cdna_arch',
'is_metal_arch',
'CUDA',
'CDNA',
'METAL',
'CPU',
"is_cpu_arch",
"is_cuda_arch",
"is_volta_arch",
"is_ampere_arch",
"is_ada_arch",
"is_hopper_arch",
"is_tensorcore_supported_precision",
"has_mma_support",
"is_cdna_arch",
"is_metal_arch",
"CUDA",
"CDNA",
"METAL",
"CPU",
]
......@@ -7,9 +7,7 @@ class TileDevice:
self.reg_cap: int = 0 # Register capacity: The amount of register memory available
self.smem_cap: int = 0 # Shared memory capacity: The amount of shared memory available
self.compute_max_core: int = 0 # The maximum number of computing cores
self.warp_size: int = (
0 # The size of a warp, a group of threads that execute instructions in lockstep
)
self.warp_size: int = 0 # The size of a warp, a group of threads that execute instructions in lockstep
self.sm_partition: int = 0 # The number of streaming multiprocessor partitions
self.transaction_size: list[int] = [
0,
......@@ -21,9 +19,7 @@ class TileDevice:
0,
] # Bandwidth specifications, possibly including peak and sustained rates
self.platform: str = "unknown" # The platform or manufacturer of the device
self.compute_capability: str = (
"unknown" # The compute capability, indicating the feature set and performance level
)
self.compute_capability: str = "unknown" # The compute capability, indicating the feature set and performance level
self.l2_cache_size_bytes: int = 0
# the number of transaction size in bytes
self.transaction_size: list[int] = [0, 0] # in bytes
......
......@@ -9,7 +9,6 @@ def is_cdna_arch(arch: TileDevice) -> bool:
class CDNA(TileDevice):
def __init__(self, target: Target | str):
if isinstance(target, str):
target = tvm.target.Target(target)
......@@ -33,6 +32,6 @@ class CDNA(TileDevice):
__all__ = [
'is_cdna_arch',
'CDNA',
"is_cdna_arch",
"CDNA",
]
......@@ -10,7 +10,6 @@ def is_cpu_arch(arch: TileDevice) -> bool:
# For LLVM Backend, we do not provide the detailed information of the CPU
# As the LLVM backend do not required tuning, just maintain the consistency
class CPU(TileDevice):
def __init__(self, target: Target):
self.target = target
device = tvm.runtime.cpu(0)
......@@ -21,6 +20,6 @@ class CPU(TileDevice):
__all__ = [
'is_cpu_arch',
'CPU',
"is_cpu_arch",
"CPU",
]
......@@ -78,7 +78,6 @@ hopper_tensorcore_supported = ada_tensorcore_supported
# instead of assuming both a and b share the same dtype.
# As the tensorcore may supports float8_e4m3 * float8_e5m2
def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool:
if is_volta_arch(arch):
return (in_dtype, accum_dtype) in volta_tensorcore_supported
elif is_ampere_arch(arch):
......@@ -92,7 +91,6 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til
class TensorInstruction:
def __init__(
self,
name: str,
......@@ -104,7 +102,6 @@ class TensorInstruction:
class CUDA(TileDevice):
def __init__(self, target: Target | str):
if isinstance(target, str):
target = tvm.target.Target(target)
......@@ -148,12 +145,12 @@ class CUDA(TileDevice):
__all__ = [
'is_cuda_arch',
'is_volta_arch',
'is_ampere_arch',
'is_ada_arch',
'is_hopper_arch',
'is_tensorcore_supported_precision',
'has_mma_support',
"is_cuda_arch",
"is_volta_arch",
"is_ampere_arch",
"is_ada_arch",
"is_hopper_arch",
"is_tensorcore_supported_precision",
"has_mma_support",
"CUDA",
]
......@@ -83,8 +83,7 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
"""
assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
shared_mem = get_device_attribute(
cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id)
shared_mem = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id)
if format == "bytes":
return shared_mem
elif format == "kb":
......