Unverified Commit a7c9a8b9 authored by Siyuan Feng's avatar Siyuan Feng Committed by GitHub
Browse files

Refactor to support upstream tvm (#595)

**Summary of part of the rebase PR:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters using `typing`:
   ```python
   dyn_type = T.int32 or T.float
   @T.prim_func
   def main(
       a: dyn_type,
   )
   ```

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
           for bx in...
parent 8edd6941
......@@ -3,6 +3,7 @@
from tvm import tir
from typing import Union
from tilelang.language import has_let_value, get_let_value
from tilelang.utils.language import get_buffer_region_from_load
def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr):
......@@ -36,6 +37,12 @@ def clear(buffer: Union[tir.Buffer, tir.Var]):
buffer_region = get_let_value(buffer) # Get the actual buffer region from variable
if isinstance(buffer_region, tir.BufferRegion):
return fill(buffer_region, 0)
elif isinstance(buffer_region, tir.BufferLoad):
region = get_buffer_region_from_load(buffer_region)
if region is None:
raise ValueError(
f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}")
return fill(region, 0)
else:
raise ValueError(f"Invalid buffer region: {buffer_region}")
raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}")
return fill(buffer, 0)
"""Override the LetFrame to print a message when entering the frame."""
from tvm._ffi import register_object as _register_object
from tvm.ffi import register_object as _register_object
from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion
from tvm.ir import Range
from tvm import DataType
......
......@@ -5,7 +5,7 @@ from collections import deque
from tvm import tir
from tvm.tir import Var
from tvm.script.ir_builder.tir.frame import TIRFrame, BlockFrame
from tvm._ffi import register_object
from tvm.ffi import register_object
from tilelang import _ffi_api
import threading
......
"""The language interface for tl programs."""
from tilelang import language as T
from tvm.tir import Buffer, BufferRegion
from tvm.ir import Range
from tvm.tir import Buffer, BufferRegion, BufferLoad
from tvm import tir
from typing import Union
from tilelang.utils.language import get_buffer_elems
......@@ -28,16 +27,17 @@ def any_of(buffer: Union[T.Tensor, BufferRegion]):
for i, r in enumerate(region):
extent = r.extent
if extent == 1:
new_region.append(r)
new_region.append(r.min)
else:
# check the idx is the last dimension
if i != len(region) - 1:
raise ValueError(
"Only support the last dimension to be for T.any currently, please contact us if you need this feature"
)
new_region.append(Range(r.min, 1))
buffer = BufferRegion(buffer, new_region)
return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer), extent)
new_region.append(r.min)
buffer_load = BufferLoad(buffer, new_region)
return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer_load),
extent)
else:
raise ValueError(f"Invalid buffer type: {type(buffer)}")
......@@ -62,15 +62,16 @@ def all_of(buffer: Union[T.Tensor, BufferRegion]):
for i, r in enumerate(region):
extent = r.extent
if extent == 1:
new_region.append(r)
new_region.append(r.min)
else:
# check the idx is the last dimension
if i != len(region) - 1:
raise ValueError(
"Only support the last dimension to be for T.any currently, please contact us if you need this feature"
)
new_region.append(Range(r.min, 1))
buffer = BufferRegion(buffer, new_region)
return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer), extent)
new_region.append(r.min)
buffer_load = BufferLoad(buffer, new_region)
return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer_load),
extent)
else:
raise ValueError(f"Invalid buffer type: {type(buffer)}")
from tvm._ffi.registry import register_func
from tvm.ffi.registry import register_func
from tvm.ir import make_node
......@@ -10,7 +10,7 @@ def mem_info_local_var():
tvm.ir.make_node: A node containing memory information
"""
return make_node(
"MemoryInfo",
"target.MemoryInfo",
unit_bits=8,
max_num_bits=64,
max_simd_bits=128,
......
......@@ -21,7 +21,7 @@
from typing import Type
from tvm import tir
from tvm._ffi.runtime_ctypes import DataType, DataTypeCode
from tvm.ffi.runtime_ctypes import DataType, DataTypeCode
from tvm.tir import IntImm
from tvm.tir.expr import FloatImm
......@@ -88,10 +88,10 @@ def _register_expr_op(ty: Type): # pylint: disable=invalid-name
if DataType(a.dtype).lanes == DataType(b.dtype).lanes:
return op(a, b)
elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes:
elif (DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes):
broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes)
return op(broadcast_a, b)
elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes:
elif (DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes):
broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes)
return op(a, broadcast_b)
else:
......
......@@ -8,7 +8,7 @@ from tvm.script.parser._core import parse, scan_macro, utils
def prim_func(func: Optional[Callable] = None,
private: bool = False,
check_well_formed=True) -> Union[PrimFunc, Callable]:
check_well_formed=False) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
......
"""The language interface for tl programs."""
from tvm.script.ir_builder.tir.frame import TIRFrame
from tvm._ffi import register_object
from tvm.ffi import register_object
from tilelang import _ffi_api
from .kernel import get_thread_bindings, get_thread_extents
from typing import List
......
......@@ -9,7 +9,7 @@ from tilelang.layout import Layout
from typing import List
@tvm._ffi.register_object("tl.Fragment")
@tvm.ffi.register_object("tl.Fragment")
class Fragment(Layout):
"""
A Fragment layout object that encapsulates iteration variables (forward_vars),
......@@ -90,7 +90,9 @@ class Fragment(Layout):
forward_thread = forward_thread_fn(*vars)
# Ensure forward_index is an array if it isn't None
if forward_index is not None and not isinstance(forward_index, tvm.ir.container.Array):
if forward_index is None:
forward_index = []
elif not isinstance(forward_index, tvm.ir.container.Array):
forward_index = [forward_index]
# Call TVM FFI constructor to set up internal data structures
......
......@@ -9,7 +9,7 @@ from typing import List
# Register the Layout class as a TVM object under the name "tl.Layout"
@tvm._ffi.register_object("tl.Layout")
@tvm.ffi.register_object("tl.Layout")
class Layout(Node):
def __init__(self, shape, forward_fn):
......
......@@ -180,7 +180,7 @@ def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
return tir.reinterpret("e5m2_float8", val).astype("float16")
return tir.reinterpret("float8_e5m2", val).astype("float16")
def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):
......
......@@ -87,8 +87,8 @@ def LowerHopperIntrin():
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerHopperIntrin() \
if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f # type: ignore
return (_ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f
) # type: ignore
def WarpSpecializedPipeline():
......@@ -375,3 +375,32 @@ def LowerSharedBarrier():
"""LowerSharedBarrier
"""
return _ffi_api.LowerSharedBarrier() # type: ignore
def StorageRewrite():
    """Create the StorageRewrite transform pass.

    Thin wrapper over the FFI-registered pass of the same name.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.StorageRewrite()  # type: ignore
    return fpass
def LowerOpaqueBlock():
    """Create the LowerOpaqueBlock transform pass.

    Thin wrapper over the FFI-registered pass of the same name.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LowerOpaqueBlock()  # type: ignore
    return fpass
def LowerThreadAllreduce():
    """Create the LowerThreadAllreduce transform pass.

    Thin wrapper over the FFI-registered pass of the same name.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LowerThreadAllreduce()  # type: ignore
    return fpass
def LowerDeviceKernelLaunch():
    """Create the LowerDeviceKernelLaunch transform pass.

    Thin wrapper over the FFI-registered pass of the same name.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LowerDeviceKernelLaunch()  # type: ignore
    return fpass
"""FFI APIs for tilelang"""
import tvm._ffi
import tvm.ffi
# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
tvm._ffi._init_api("tl.transform", __name__) # pylint: disable=protected-access
tvm.ffi._init_api("tl.transform", __name__) # pylint: disable=protected-access
from tvm.tir import Buffer
from typing import List
from typing import List, Optional
from functools import reduce
from tvm import IRModule
from tvm.tir import PrimFunc
from tvm import ir, tir
# Scope Checkers for TVM Buffers
# These utility functions check the memory scope of a given TVM buffer.
......@@ -118,3 +119,20 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
"The optimized module should only have one global variable for default schedule.")
func = list(ir_module.functions.values())[0]
return func
def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> Optional[tir.BufferRegion]:
    """Convert a vectorized ``BufferLoad`` into the equivalent ``BufferRegion``.

    Loads such as ``C[0:128, 0:32]`` are represented with ``tir.Ramp``
    indices (buffer-wise ops, see https://github.com/apache/tvm/pull/14693);
    this helper rebuilds the rectangular region those ramps describe.

    Parameters
    ----------
    buffer_load : tir.BufferLoad
        The load whose indices may encode a region.

    Returns
    -------
    Optional[tir.BufferRegion]
        The region covered by the load, or ``None`` when any index is not a
        ``Ramp`` (e.g. a scalar element load) or provably has a non-unit
        stride, since such loads cannot be expressed as a ``BufferRegion``.
    """
    regions = []
    for index in buffer_load.indices:
        if not isinstance(index, tir.Ramp):
            return None
        # Range(base, lanes) only matches the ramp when its stride is 1; a
        # non-unit stride touches non-contiguous elements, which a rectangular
        # region cannot represent. Only reject when the stride is a constant
        # proven != 1; symbolic strides are passed through as before.
        stride = index.stride
        if isinstance(stride, tir.IntImm) and stride.value != 1:
            return None
        regions.append(ir.Range.from_min_extent(index.base, index.lanes))
    return tir.BufferRegion(buffer_load.buffer, regions)
......@@ -19,12 +19,12 @@ class TensorSupplyType(Enum):
def map_torch_type(intype: str) -> torch.dtype:
if intype == "e4m3_float8":
if intype == "float8_e4m3":
assert hasattr(torch, "float8_e4m3fn"), \
"torch.float8_e4m3fn is not supported in this version of torch" \
"Please upgrade torch >= 2.1.0"
return torch.float8_e4m3fn
elif intype == "e5m2_float8":
elif intype == "float8_e5m2":
assert hasattr(torch, "float8_e5m2"), \
"torch.float8_e5m2 is not supported in this version of torch" \
"Please upgrade torch >= 2.1.0"
......@@ -40,10 +40,10 @@ def map_torch_type(intype: str) -> torch.dtype:
def adapt_torch2tvm(arg):
float8_dtype_map = {
torch.float8_e4m3fn: "e4m3_float8",
torch.float8_e4m3fnuz: "e4m3_float8",
torch.float8_e5m2: "e5m2_float8",
torch.float8_e5m2fnuz: "e5m2_float8",
torch.float8_e4m3fn: "float8_e4m3",
torch.float8_e4m3fnuz: "float8_e4m3",
torch.float8_e5m2: "float8_e5m2",
torch.float8_e5m2fnuz: "float8_e5m2",
}
if isinstance(arg, torch.Tensor):
if arg.dtype in {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment