Commit 0b521a3b authored by Lei Wang, committed by LeiWang1999

[Enhancement] Support composable expressions for shapes with symbolic vars (#624)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve their string representation.
- Replaced the deprecated legalize_c calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced the individual isinstance checks for binary operations with a single type-to-operator mapping, streamlining the code and making expression rendering cheaper (a sketch of this approach follows below).
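
The mapping approach might look roughly like the following (a reconstruction from this message, not the actual tilelang source; render and the two tables are illustrative names):

from tvm import tir

_BINOP = {tir.Add: "+", tir.Sub: "-", tir.Mul: "*",
          tir.FloorDiv: "//", tir.FloorMod: "%"}
_CALLOP = {tir.Min: "min", tir.Max: "max"}

def render(e):
    t = type(e)
    if t in _BINOP:      # one dict lookup replaces a chain of isinstance checks
        return f"({render(e.a)} {_BINOP[t]} {render(e.b)})"
    if t in _CALLOP:     # Min/Max render as Python min()/max() calls
        return f"{_CALLOP[t]}({render(e.a)}, {render(e.b)})"
    if isinstance(e, tir.Var):
        return e.name
    return str(e.value) if isinstance(e, tir.IntImm) else str(e)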

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to add parentheses only where operator precedence requires them, improving the accuracy of the rendered expressions (see the sketch after this list).
- Included a docstring for better clarity on the function's purpose and usage.
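
Combined with the mapping above, a precedence-aware version could look like this sketch (again a hypothetical reconstruction over the tvm.tir node classes; the real pythonic_expr in utils.py may differ in detail):

from tvm import tir

_SYM = {tir.Add: "+", tir.Sub: "-", tir.Mul: "*",
        tir.FloorDiv: "//", tir.FloorMod: "%"}
_PREC = {tir.Add: 1, tir.Sub: 1, tir.Mul: 2, tir.FloorDiv: 2, tir.FloorMod: 2}

def pythonic_expr(expr) -> str:
    """Convert a TVM PrimExpr to a Python-style string with minimal parentheses."""
    def visit(e, parent_prec):
        t = type(e)
        if t in _SYM:
            prec = _PREC[t]
            # Left-associative operators: the right child needs parentheses at
            # equal precedence for the non-commutative ones (-, //, %).
            right_prec = prec if t in (tir.Add, tir.Mul) else prec + 1
            s = f"{visit(e.a, prec)} {_SYM[t]} {visit(e.b, right_prec)}"
            return f"({s})" if prec < parent_prec else s
        if t in (tir.Min, tir.Max):
            name = "min" if t is tir.Min else "max"
            return f"{name}({visit(e.a, 0)}, {visit(e.b, 0)})"
        if isinstance(e, tir.Var):
            return e.name
        if isinstance(e, tir.IntImm):
            return str(e.value)
        return str(e)  # fallback for nodes this sketch does not model
    return visit(expr, 0)

n, m = tir.Var("n", "int32"), tir.Var("m", "int32")
print(pythonic_expr((n + 1) * m))   # (n + 1) * m   -- parens kept where needed
print(pythonic_expr(n + m * 2))     # n + m * 2     -- no redundant parens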

* test fix

* minor update
parent 22aed721
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from typing import List, Union, Optional
 import torch
 from tilelang import tvm as tvm
-from tvm.tir import Buffer, IntImm, Var
+from tvm.tir import Buffer, IntImm, Var, PrimExpr
 from tilelang.utils.tensor import map_torch_type
@@ -36,10 +36,10 @@ class KernelParam:
         for s in buffer.shape:
             if isinstance(s, IntImm):
                 shape.append(s.value)
-            elif isinstance(s, Var):
+            elif isinstance(s, (Var, PrimExpr)):
                 shape.append(s)
             else:
-                raise ValueError(f"Unsupported dimension type: {type(s)}")
+                raise ValueError(f"Unsupported dimension type: {type(s)} {s}")
         return cls(dtype, shape)

     @classmethod
...
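
For context, this hunk lets KernelParam keep a composed symbolic expression such as n * 2 as a buffer dimension instead of raising. A hedged illustration of the effect (tir.decl_buffer stands in for however the buffer is actually built; since Var is itself a PrimExpr subclass, the tuple check is effectively just a PrimExpr check):

from tvm import tir

n = tir.Var("n", "int32")
buf = tir.decl_buffer((n * 2, 128), "float16", name="A")  # composed symbolic dim

for s in buf.shape:
    if isinstance(s, tir.IntImm):
        print("static dim:", s.value)          # 128
    elif isinstance(s, (tir.Var, tir.PrimExpr)):
        # Before this change only bare Vars passed; n * 2 (a tir.Mul) now does too.
        print("symbolic dim:", s)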
@@ -145,6 +145,12 @@ cdef class CythonKernelWrapper:
             else:  # Already converted to Python int during initialization
                 shape.append(s)
             device = inputs[0].device if len(inputs) > 0 else torch.cuda.current_device()
+            if len(shape) == 0:
+                param_name = self.params[i].name if hasattr(self.params[i], 'name') else f'parameter_{i}'
+                raise ValueError(
+                    f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. "
+                    f"Expected shape: {shape}"
+                )
             tensor = torch.empty(*shape, dtype=dtype, device=device)
         else:
             tensor = inputs[ins_idx]
...
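
The wrapper hunk adds an early guard before output allocation. A standalone sketch of the same logic (alloc_output and its defaults are hypothetical names, not the wrapper's API):

import torch

def alloc_output(shape, dtype=torch.float16, device="cpu", param_name="out_0"):
    # Mirrors the guard above: fail early with the parameter's name rather
    # than letting an empty shape reach torch.empty.
    if len(shape) == 0:
        raise ValueError(
            f"Cannot create output tensor (name={param_name}) - 0-dimensional "
            f"tensors are not supported. Expected shape: {shape}")
    return torch.empty(*shape, dtype=dtype, device=device)

alloc_output([1024, 512])    # OK
# alloc_output([])           # raises ValueError naming the offending parameter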