Unverified Commit fd6cec58 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Autotune] Add autotune coverage for symbolic M and normalize cache key (#1075)

- extend matmul autotune test suite with a symbolic M case and allow run_autotune to accept concrete values for symbolic dims
  - sanitize _kernel_parameters when generating cache keys so symbolic vars serialize deterministically
parent ba410ae3
...@@ -116,10 +116,22 @@ def matmul(M, ...@@ -116,10 +116,22 @@ def matmul(M,
return main return main
def run_autotune(M: int, N: int, K: int): def run_autotune(M, N, K, M_value=None, N_value=None, K_value=None):
import torch import torch
a = torch.randn(M, K, dtype=torch.float16).cuda()
b = torch.randn(N, K, dtype=torch.float16).cuda() def _resolve(dim, provided, name):
if isinstance(dim, T.Var):
if provided is None:
raise ValueError(f"Dynamic dimension {name} requires a concrete value.")
return provided
return dim
actual_M = _resolve(M, M_value, "M")
actual_N = _resolve(N, N_value, "N")
actual_K = _resolve(K, K_value, "K")
a = torch.randn(actual_M, actual_K, dtype=torch.float16).cuda()
b = torch.randn(actual_N, actual_K, dtype=torch.float16).cuda()
with set_autotune_inputs([a, b]): with set_autotune_inputs([a, b]):
kernel = matmul(M, N, K) kernel = matmul(M, N, K)
...@@ -140,5 +152,9 @@ def test_autotune_matmul(): ...@@ -140,5 +152,9 @@ def test_autotune_matmul():
run_autotune(1024, 1024, 1024) run_autotune(1024, 1024, 1024)
def test_autotune_matmul_symbolic_m():
run_autotune(T.symbolic("m"), 1024, 1024, M_value=1024)
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
...@@ -231,6 +231,16 @@ class AutoTuner: ...@@ -231,6 +231,16 @@ class AutoTuner:
def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]: def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
"""Generate a cache key for the auto-tuning process. """Generate a cache key for the auto-tuning process.
""" """
def _normalize_param(value):
if isinstance(value, Var):
return str(value)
if isinstance(value, (list, tuple)):
return [_normalize_param(v) for v in value]
if isinstance(value, dict):
return {str(k): _normalize_param(v) for k, v in value.items()}
return value
# extract parameters from the function signature # extract parameters from the function signature
op_parameters = [] op_parameters = []
for _, default_value in parameters.items(): for _, default_value in parameters.items():
...@@ -238,7 +248,7 @@ class AutoTuner: ...@@ -238,7 +248,7 @@ class AutoTuner:
op_parameters.append(default_value.default) op_parameters.append(default_value.default)
if self._kernel_parameters is not None: if self._kernel_parameters is not None:
op_parameters += self._kernel_parameters op_parameters += _normalize_param(self._kernel_parameters)
func_source = inspect.getsource(self.fn) func_source = inspect.getsource(self.fn)
key_data = { key_data = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment