Refactor to support upstream tvm (#595)
**Summary of part of the rebase PR:**
1. **Support T.thread_return() → CUDA return syntax**
Added support for translating `T.thread_return()` to CUDA's native `return` statement.
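For orientation, the hand-written sketch below shows the kind of per-thread `return` this targets. It is not output of the code generator; the kernel name `scale_kernel`, its parameters, and the launch configuration are all hypothetical.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel: each thread handles one element of `data`.
__global__ void scale_kernel(float* data, int n, float factor) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) {
        return;  // per-thread early exit: the native statement T.thread_return() is translated to
    }
    data[idx] *= factor;
}

int main() {
    const int n = 1000;
    float* data = nullptr;
    cudaMallocManaged(&data, n * sizeof(float));
    for (int i = 0; i < n; ++i) data[i] = 1.0f;

    // 4 blocks of 256 threads = 1024 threads, so the last 24 threads take the early return.
    scale_kernel<<<4, 256>>>(data, n, 2.0f);
    cudaDeviceSynchronize();

    printf("data[0] = %.1f, data[%d] = %.1f\n", data[0], n - 1, data[n - 1]);
    cudaFree(data);
    return 0;
}
```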
2. **Dynamic type support for function inputs**
Functions now accept dynamically typed parameters using `typing`:
```python
dyn_type = T.int32  # or T.float

@T.prim_func
def main(
    a: dyn_type,
):
    ...
```
3. **Device Function Codegen**
Added support for generating `__device__` functions in CUDA:
```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.int32, b: T.int32) -> T.int32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "int32"),
        B: T.Buffer((128, 128), "int32"),
        C: T.Buffer((128, 128), "int32"),
    ):
        T.func_attr({"global_symbol": "main"})
        length: T.int32 = Module.add(64, 64)  # Host call
        for bx in...
```