Refactor to support upstream tvm (#595)
**Summarize part of the rebase pr:**
1. **Support T.thread_return() → CUDA return syntax**
Added support for translating `T.thread_return()` to CUDA's native `return` statement.
2. **Dynamic type support for function inputs**
Functions now accept dynamically typed parameters using `typing`:
```python
dyn_type = T.int32 or T.float
@T.prim_func
def main(
a: dyn_type,
)
```
3. **Device Function Codegen**
Added support for generating `__device__` functions in CUDA:
```python
@I.ir_module
class Module:
@T.prim_func(private=True)
def add(a: T.int32, b: T.int32) -> T.int32:
return a + b
@T.prim_func
def main(
A: T.Buffer((128, 128), "int32"),
B: T.Buffer((128, 128), "int32"),
C: T.Buffer((128, 128), "int32"),
):
T.func_attr({"global_symbol": "main"})
length: T.int32 = Module.add(64, 64) # Host call
for bx in...
Showing
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Please register or sign in to comment