Unverified commit 777881e1, authored by Kurisu, committed by GitHub
Browse files

[Feat] Add support for `T.serial` with step and negative step (#1188)



* [Feature] Support serial for with step

* add more tests

* fix

* Enhance trip count validation in SerialForWithStep to ensure non-zero step values and prevent undefined behavior. Added error handling for zero step values and improved logging for non-constant steps.

* Update builder.py

* fix lint error

---------
Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent a59d41d6
...@@ -3,6 +3,8 @@ import tilelang.language as T ...@@ -3,6 +3,8 @@ import tilelang.language as T
import torch import torch
import tilelang.testing import tilelang.testing
import tvm import tvm
from tvm.script.ir_builder.base import IRBuilderFrame
from tvm.tir.expr import IntImm, Var
def test_argument(): def test_argument():
...@@ -273,6 +275,43 @@ def test_prim_func_generator(): ...@@ -273,6 +275,43 @@ def test_prim_func_generator():
assert isinstance(foo, T.PrimFunc) assert isinstance(foo, T.PrimFunc)
def test_serial_for_with_step():
    """Check `range` / `T.serial` loops that use a non-unit or negative step.

    Two kernels are built and executed, and their outputs are compared with
    reference tensors created on ``device='cuda'``. Afterwards the return type
    of several direct `T.serial(...)` calls is checked.
    """

    # NOTE(review): `out_idx=-1` presumably marks the last argument (`A`) as the
    # kernel output, since `ker()` is called with no arguments below and its
    # return value is compared with the reference -- confirm in tilelang.jit.
    @tilelang.jit(out_idx=-1)
    @T.prim_func
    def test_stepped_serial(A: T.Tensor((10,), T.int32)):
        with T.Kernel(1) as _:
            # Positive step from 0: expected to visit the even indices only.
            for i in range(0, 10, 2):
                T.device_assert(0 <= i < 10 and i % 2 == 0, "i out of range")
                A[i] = 1.0
            # Positive step from 1: expected to visit the odd indices only.
            for i in range(1, 10, 2):
                T.device_assert(1 <= i < 10 and i % 2 == 1, "i out of range")
                A[i] = 2.0

    ker = test_stepped_serial()
    res = ker()
    # Even slots were written with 1 and odd slots with 2.
    ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda')
    assert torch.all(res == ref), f"Expected {ref}, but got {res}"

    @tilelang.jit(out_idx=-1)
    @T.prim_func
    def test_serial_step_neg(A: T.Tensor((10,), T.int32)):
        with T.Kernel(1) as _:
            # Negative step: i is expected to count down from 10 to 1.
            for i in range(10, 0, -1):
                T.device_assert(0 < i <= 10, "i out of range")
                # Slot (10 - i) receives i, so the output is in descending order.
                A[10 - i] = i

    ker = test_serial_step_neg()
    res = ker()
    ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda')
    assert torch.all(res == ref), f"Expected {ref}, but got {res}"

    # A constant step of 1 (Python int or IntImm) yields an IRBuilderFrame.
    assert isinstance(T.serial(1, 10, 1), IRBuilderFrame)
    assert isinstance(T.serial(1, 10, IntImm('int32', 1)), IRBuilderFrame)
    # A non-constant step or a step other than 1 does not yield an IRBuilderFrame.
    assert not isinstance(T.serial(1, 10, Var('tmp', 'int32')), IRBuilderFrame)
    assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame)
def test_swap_logic(): def test_swap_logic():
@tilelang.jit @tilelang.jit
......
...@@ -23,9 +23,7 @@ from .proxy import ( ...@@ -23,9 +23,7 @@ from .proxy import (
SharedBuffer, # noqa: F401 SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401 LocalBuffer, # noqa: F401
) )
from .parallel import Parallel # noqa: F401 from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401
from .pipeline import Pipelined # noqa: F401
from .persistent import Persistent # noqa: F401
from .frame import has_let_value, get_let_value # noqa: F401 from .frame import has_let_value, get_let_value # noqa: F401
from .math_intrinsics import * # noqa: F401 from .math_intrinsics import * # noqa: F401
from .kernel import ( from .kernel import (
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from typing import Any
from tvm import tir from tvm import tir
from tvm.tir import IntImm from tvm.tir import IntImm
import tvm.script.ir_builder.tir as tb_tir
from .v2.builder import SerialForWithStep
from tilelang import _ffi_api from tilelang import _ffi_api
def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
    """Build a (possibly nested) parallel for loop over the given extents.

    Useful for writing element-wise tensor expressions.

    Parameters
    ----------
    extents : PrimExpr
        One extent per loop dimension; forwarded as a tuple.

    coalesced_width : Optional[int]
        When given, it is attached to the loop under the
        ``"coalesced_width"`` annotation key. When ``None`` no annotation
        is added.

    Returns
    -------
    res : frame.ForFrame
        The ForFrame.
    """
    # Only emit the annotation when the caller actually asked for it.
    if coalesced_width is None:
        loop_annotations: dict[str, Any] = {}
    else:
        loop_annotations = {"coalesced_width": coalesced_width}
    return _ffi_api.Parallel(extents, loop_annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
def Persistent(
    domain: list[tir.PrimExpr],
    wave_size: tir.PrimExpr,
    index: tir.PrimExpr,
    group_size: tir.PrimExpr | None = 8,
):
    """Build a persistent for loop.

    All arguments are forwarded unchanged, in order, to
    ``_ffi_api.Persistent``.

    Parameters
    ----------
    domain : List[tir.PrimExpr]
        The iteration domain.
        NOTE(review): the previous wording was "list of dominators";
        presumably this is the list of domain extents -- confirm.
    wave_size : tir.PrimExpr
        The wave size.
    index : tir.PrimExpr
        The tile index in one wave.
    group_size : tir.PrimExpr
        The group size. Defaults to 8.
    """
    persistent_frame = _ffi_api.Persistent(domain, wave_size, index, group_size)
    return persistent_frame
def Pipelined( def Pipelined(
start: tir.PrimExpr, start: tir.PrimExpr,
stop: tir.PrimExpr = None, stop: tir.PrimExpr = None,
...@@ -44,3 +92,20 @@ def Pipelined( ...@@ -44,3 +92,20 @@ def Pipelined(
group = [] group = []
# type: ignore[attr-defined] # pylint: disable=no-member # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group) return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group)
def serial(start: tir.PrimExpr,
           stop: tir.PrimExpr | None = None,
           step: tir.PrimExpr | None = None,
           *,
           annotations: dict[str, Any] | None = None):
    """Construct a serial for loop, optionally with a step.

    Parameters
    ----------
    start : PrimExpr
        Loop start. On the stepped path, if `stop` is omitted, this value is
        used as the stop and the loop starts at 0.
    stop : Optional[PrimExpr]
        Loop stop.
    step : Optional[PrimExpr]
        Loop step. ``None`` and a constant 1 (Python ``int`` or ``IntImm``)
        take the plain serial path.
    annotations : Optional[dict[str, Any]]
        Loop annotations, forwarded on both paths.

    Returns
    -------
    The result of ``tvm.script.ir_builder.tir.serial`` when the step is
    ``None`` or a constant 1; otherwise a ``SerialForWithStep`` record holding
    start, stop, step and annotations.
    """
    # Decide whether the plain (unit-step) serial loop can be used.
    if step is None:
        use_plain_serial = True
    elif isinstance(step, int):
        use_plain_serial = step == 1
    elif isinstance(step, IntImm):
        use_plain_serial = step.value == 1
    else:
        # Non-constant step: cannot be proven to be 1.
        use_plain_serial = False

    if use_plain_serial:
        return tb_tir.serial(start, stop, annotations=annotations)

    # Single-bound form: treat the only bound as the stop and start from zero,
    # keeping the dtype of the bound when it has one.
    if stop is None:
        stop = start
        start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
    return SerialForWithStep(start, stop, step, annotations=annotations)
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Any
from tvm import tir
from tilelang import _ffi_api
def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
    """Build a (possibly nested) parallel for loop over the given extents.

    Useful for writing element-wise tensor expressions.

    Parameters
    ----------
    extents : PrimExpr
        One extent per loop dimension; forwarded as a tuple.

    coalesced_width : Optional[int]
        When given, it is attached to the loop under the
        ``"coalesced_width"`` annotation key. When ``None`` no annotation
        is added.

    Returns
    -------
    res : frame.ForFrame
        The ForFrame.
    """
    # Only emit the annotation when the caller actually asked for it.
    if coalesced_width is None:
        loop_annotations: dict[str, Any] = {}
    else:
        loop_annotations = {"coalesced_width": coalesced_width}
    return _ffi_api.Parallel(extents, loop_annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
"""The language interface for tl programs."""
from __future__ import annotations
from tvm import tir
from tilelang import _ffi_api
def Persistent(
    domain: list[tir.PrimExpr],
    wave_size: tir.PrimExpr,
    index: tir.PrimExpr,
    group_size: tir.PrimExpr | None = 8,
):
    """Build a persistent for loop.

    All arguments are forwarded unchanged, in order, to
    ``_ffi_api.Persistent``.

    Parameters
    ----------
    domain : List[tir.PrimExpr]
        The iteration domain.
        NOTE(review): the previous wording was "list of dominators";
        presumably this is the list of domain extents -- confirm.
    wave_size : tir.PrimExpr
        The wave size.
    index : tir.PrimExpr
        The tile index in one wave.
    group_size : tir.PrimExpr
        The group size. Defaults to 8.
    """
    persistent_frame = _ffi_api.Persistent(domain, wave_size, index, group_size)
    return persistent_frame
...@@ -100,6 +100,14 @@ class BreakFrame(Frame): ...@@ -100,6 +100,14 @@ class BreakFrame(Frame):
... ...
@dataclass
class SerialForWithStep:
    """Plain record describing a serial loop with an explicit step.

    It only stores the loop bounds, step and annotations; it carries no
    behavior of its own.
    """

    # First value of the loop index.
    start: PrimExpr
    # Bound at which iteration ends.
    stop: PrimExpr
    # Increment applied to the loop index on each iteration.
    step: PrimExpr
    # Optional loop annotations; None when the caller supplied none.
    annotations: dict[str, Any] | None = None
# Python 3.9 compatibility: avoid PEP 604 unions at runtime # Python 3.9 compatibility: avoid PEP 604 unions at runtime
# Use tuple for isinstance checks and typing.Union for annotations/aliases # Use tuple for isinstance checks and typing.Union for annotations/aliases
ContinueOrBreak = (ContinueFrame, BreakFrame) ContinueOrBreak = (ContinueFrame, BreakFrame)
...@@ -243,12 +251,32 @@ class Builder(BaseBuilder): ...@@ -243,12 +251,32 @@ class Builder(BaseBuilder):
def ctx_for(self, it):
    """Enter the frame for a ``for`` statement and yield its loop value.

    This is a generator that yields exactly once.
    NOTE(review): presumably wrapped as a context manager by a decorator
    outside this view -- confirm.

    Parameters
    ----------
    it
        The loop iterable after ``unwrap_expr``: either a
        ``tir.frame.ForFrame`` or a ``SerialForWithStep``.

    Yields
    ------
    For a ``ForFrame``, whatever entering the frame produces. For a
    ``SerialForWithStep``, the expression ``start + v * step`` where ``v`` is
    the variable of a zero-based serial loop over the trip count.

    Raises
    ------
    TypeError
        If ``it`` is neither a ``SerialForWithStep`` nor a ``ForFrame``.
    ValueError
        If a constant step (``int`` or ``IntImm``) is zero.
    """
    self.check_continue_break()
    it = unwrap_expr(it)

    # Ordinary loop frames are entered directly.
    if not isinstance(it, SerialForWithStep):
        if not isinstance(it, tir.frame.ForFrame):
            raise TypeError(
                f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
                "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding")
        with self.with_frame(it) as loop_var:
            yield loop_var
        return

    # Stepped serial loop: validate the step and compute the trip count
    # before constructing the frame.
    if isinstance(it.step, (int, IntImm)):
        const_step = it.step if isinstance(it.step, int) else it.step.value
        if const_step == 0:
            raise ValueError('Invalid stepped serial: step must be non-zero')
        if const_step > 0:
            trip_count = tir.ceildiv(it.stop - it.start, const_step)
        else:
            # Counting down: measure the distance from stop up to start and
            # divide by the step magnitude.
            trip_count = tir.ceildiv(it.start - it.stop, -const_step)
    else:
        # The sign of a non-constant step is unknown here, so the formula for
        # a positive step is used and the user is warned.
        logger.warning(
            f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang'
        )
        trip_count = tir.ceildiv(it.stop - it.start, it.step)

    # Iterate a zero-based counter and map it back to the user-visible index.
    counter_frame = tir.serial(trip_count, annotations=it.annotations)
    with self.with_frame(counter_frame) as counter:
        IRBuilder.name('_tmp', counter)
        yield it.start + counter * it.step
def ctx_continue(self): def ctx_continue(self):
self.check_continue_break() self.check_continue_break()
...@@ -459,8 +487,9 @@ class Builder(BaseBuilder): ...@@ -459,8 +487,9 @@ class Builder(BaseBuilder):
f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.") f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")
def override(self, name: str):
    """Return the replacement callable for an overridden builtin name.

    Only ``'range'`` is supported; it maps to ``tilelang.language.serial``.

    Raises
    ------
    ValueError
        If ``name`` is anything other than ``'range'``.
    """
    # Imported inside the method rather than at module level.
    # NOTE(review): presumably to avoid a circular import with
    # tilelang.language -- confirm.
    from tilelang.language import serial
    if name != 'range':
        raise ValueError(f'Unknown override: {name}')
    return serial
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment