Unverified commit fd6cec58, authored by Lei Wang and committed by GitHub

[Autotune] Add autotune coverage for symbolic M and normalize cache key (#1075)

- Extend the matmul autotune test suite with a symbolic M case, and allow run_autotune to accept concrete values for symbolic dims (usage sketched below).
- Sanitize _kernel_parameters when generating cache keys so symbolic vars serialize deterministically.
parent ba410ae3
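
In brief, the calling convention introduced by the first hunk looks like this (a sketch distilled from the test below; it assumes T is tilelang's language module, so T.symbolic("m") yields a T.Var):

    run_autotune(1024, 1024, 1024)                           # static dims: unchanged path
    run_autotune(T.symbolic("m"), 1024, 1024, M_value=1024)  # symbolic M, concrete input shape
    run_autotune(T.symbolic("m"), 1024, 1024)                # raises ValueError: M needs a concrete value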
@@ -116,10 +116,22 @@ def matmul(M,
     return main
 
 
-def run_autotune(M: int, N: int, K: int):
+def run_autotune(M, N, K, M_value=None, N_value=None, K_value=None):
     import torch
-    a = torch.randn(M, K, dtype=torch.float16).cuda()
-    b = torch.randn(N, K, dtype=torch.float16).cuda()
+
+    def _resolve(dim, provided, name):
+        if isinstance(dim, T.Var):
+            if provided is None:
+                raise ValueError(f"Dynamic dimension {name} requires a concrete value.")
+            return provided
+        return dim
+
+    actual_M = _resolve(M, M_value, "M")
+    actual_N = _resolve(N, N_value, "N")
+    actual_K = _resolve(K, K_value, "K")
+
+    a = torch.randn(actual_M, actual_K, dtype=torch.float16).cuda()
+    b = torch.randn(actual_N, actual_K, dtype=torch.float16).cuda()
     with set_autotune_inputs([a, b]):
         kernel = matmul(M, N, K)
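
A design note on this hunk: the benchmark tensors are allocated with the resolved concrete sizes (actual_M, actual_K, and so on), while the kernel is still instantiated as matmul(M, N, K) with the symbolic dim, so autotuning compiles and tunes the dynamic-shape kernel but feeds it a representative concrete input.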
@@ -140,5 +152,9 @@ def test_autotune_matmul():
     run_autotune(1024, 1024, 1024)
 
 
+def test_autotune_matmul_symbolic_m():
+    run_autotune(T.symbolic("m"), 1024, 1024, M_value=1024)
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
@@ -231,6 +231,16 @@ class AutoTuner:
     def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
         """Generate a cache key for the auto-tuning process.
         """
+
+        def _normalize_param(value):
+            if isinstance(value, Var):
+                return str(value)
+            if isinstance(value, (list, tuple)):
+                return [_normalize_param(v) for v in value]
+            if isinstance(value, dict):
+                return {str(k): _normalize_param(v) for k, v in value.items()}
+            return value
+
         # extract parameters from the function signature
         op_parameters = []
         for _, default_value in parameters.items():
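
To illustrate, _normalize_param rewrites only Var leaves and recurses through containers, so a hypothetical parameter bundle would normalize as follows (the exact string for a Var depends on Var.__str__; "m" here is an assumption):

    _normalize_param([T.symbolic("m"), 1024, {"block": (128, 64)}])
    # -> ["m", 1024, {"block": [128, 64]}]  # tuples become lists, dict keys are stringified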
@@ -238,7 +248,7 @@ class AutoTuner:
             op_parameters.append(default_value.default)
 
         if self._kernel_parameters is not None:
-            op_parameters += self._kernel_parameters
+            op_parameters += _normalize_param(self._kernel_parameters)
 
         func_source = inspect.getsource(self.fn)
         key_data = {
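
The motivation for stringifying Vars: a raw Var is not JSON-serializable, so serializing _kernel_parameters directly into the cache key would either fail or depend on object identity from process to process. A minimal self-contained sketch of the idea, assuming hashlib/json stand in for the real cache-key machinery and FakeVar for the actual Var class:

    import hashlib
    import json

    class FakeVar:
        """Illustrative stand-in for a symbolic Var (an assumption, not the real class)."""

        def __init__(self, name):
            self.name = name

        def __str__(self):
            return self.name

    def _normalize_param(value):
        # Mirrors the helper above: stringify vars, recurse into containers.
        if isinstance(value, FakeVar):
            return str(value)
        if isinstance(value, (list, tuple)):
            return [_normalize_param(v) for v in value]
        if isinstance(value, dict):
            return {str(k): _normalize_param(v) for k, v in value.items()}
        return value

    params = [FakeVar("m"), 1024, 1024]
    # json.dumps(params) raises TypeError; the normalized form hashes stably:
    key = hashlib.sha256(json.dumps(_normalize_param(params), sort_keys=True).encode()).hexdigest()
    print(key)  # identical digest on every run and in every process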