Unverified Commit 388ee7ee authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Implement dynamic unroll factor in CUDA code generation (#1360)

* [Enhancement] Implement dynamic unroll factor in CUDA code generation

This commit adds support for specifying a dynamic unroll factor in CUDA code generation. A new `unroll_factor` map stores the unroll factor associated with each loop variable, so that the emitted `#pragma unroll` can carry an explicit factor, allowing for more flexible and optimized loop unrolling. Additionally, an `unroll` function is added to the loop language, enabling users to define unroll factors directly in their code. This enhancement improves performance by allowing tailored unrolling strategies based on specific loop characteristics.

* lint fix

* [Bugfix] Correct initialization of non-zero counters in custom compress kernel and update TIR registration for gemm_sp_py to use the correct tile operation
parent e547d247
......@@ -275,8 +275,9 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
for tm in T.Parallel(block_M):
for g_i in range(0, block_K // group):
a_k = g_i * group
T.clear(non_zero_cnt)
T.clear(non_zero_elt_log_idx)
non_zero_cnt[0] = 0
for i in range(elem):
non_zero_elt_log_idx[i] = 0
for i in range(group):
val = A_shared[tm, a_k + i]
if val != 0.0:
......
......@@ -312,7 +312,12 @@ std::string CodeGenTileLangCUDA::Finish() {
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent();
stream << "#pragma unroll\n";
if (unroll_factor.count(op->loop_var.get())) {
stream << "#pragma unroll "
<< PrintExpr(unroll_factor[op->loop_var.get()]) << "\n";
} else {
stream << "#pragma unroll\n";
}
}
std::string extent =
PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
......@@ -2661,7 +2666,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
this->VisitStmt(op->body);
return;
} else if (op->attr_key == "pragma_unroll_factor") {
const IntImmNode *factor = op->value.as<IntImmNode>();
ICHECK(factor);
unroll_factor[op->node.as<VarNode>()] = Downcast<IntImm>(factor);
}
CodeGenC::VisitStmt_(op);
}
......
......@@ -140,6 +140,7 @@ private:
std::unordered_map<const VarNode *, std::string> fragment_shapes;
std::unordered_map<const VarNode *, std::string> fragment_layouts;
std::unordered_map<const VarNode *, IntImm> unroll_factor;
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);
void PrintWmmaScope(const std::string &scope, DataType t,
......
import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T
def test_unroll_with_step():
    """Compile a kernel using a stepped ``T.unroll`` loop and check the CUDA source.

    The generated kernel source is expected to contain a ``#pragma unroll``
    directive.
    """

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                for i in T.unroll(0, 16, step=4):
                    A[0, i] = 1.0

    # Build for CUDA and inspect the emitted source text directly.
    cuda_source = tilelang.compile(main, target="cuda").get_kernel_source()
    assert "#pragma unroll" in cuda_source
def test_unroll_with_unroll_factor():
    """Compile a kernel using ``T.unroll(..., unroll_factor=4)`` and check the CUDA source.

    The generated kernel source is expected to contain ``#pragma unroll 4``,
    i.e. the requested factor must appear on the pragma.
    """

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                for i in T.unroll(0, 16, unroll_factor=4):
                    A[0, i] = 1.0

    # Build for CUDA and inspect the emitted source text directly.
    cuda_source = tilelang.compile(main, target="cuda").get_kernel_source()
    assert "#pragma unroll 4" in cuda_source
if __name__ == "__main__":
tilelang.testing.main()
......@@ -24,7 +24,15 @@ from .proxy import (
LocalBuffer, # noqa: F401
Ref, # noqa: F401
)
from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401
from .loop import (
Parallel, # noqa: F401
Persistent, # noqa: F401
Pipelined, # noqa: F401
serial, # noqa: F401
unroll, # noqa: F401
Serial, # noqa: F401
Unroll, # noqa: F401
)
from .frame import has_let_value, get_let_value # noqa: F401
from .math_intrinsics import * # noqa: F401
from .kernel import (
......
......@@ -198,7 +198,7 @@ def gemm_sp_v2(
C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.gemm_sp_py"),
tir.op.Op.get("tl.tileop.gemm_sp_py"),
A_arg,
E_arg,
B_arg,
......
......@@ -4,8 +4,9 @@ from typing import Any
from tvm import tir
from tvm.tir import IntImm
import tvm.script.ir_builder.tir as tb_tir
from .v2.builder import SerialForWithStep
from .v2.builder import SerialForWithStep, UnrollForWithStep
from tilelang import _ffi_api
from tvm.script.ir_builder.tir import frame
def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
......@@ -97,7 +98,7 @@ def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None):
annotations: dict[str, Any] | None = None) -> frame.ForFrame:
step_is_one = False
step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1
......@@ -108,3 +109,70 @@ def serial(start: tir.PrimExpr,
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
return SerialForWithStep(start, stop, step, annotations=annotations)
def unroll(start: tir.PrimExpr,
           stop: tir.PrimExpr | None = None,
           step: tir.PrimExpr | None = None,
           *,
           explicit: bool = False,
           unroll_factor: int | None = None,
           annotations: dict[str, Any] | None = None) -> frame.ForFrame:
    """The unrolled For statement.

    Parameters
    ----------
    start : PrimExpr
        The minimum value of iteration. When ``stop`` is omitted this value is
        used as the stop instead and the loop starts at 0.

    stop : PrimExpr
        The maximum value of iteration.

    step : PrimExpr
        The step size of the iteration. ``None`` or a constant 1 builds a
        plain unrolled loop; any other step builds an ``UnrollForWithStep``.

    explicit : bool
        Default value for the ``"pragma_unroll_explicit"`` annotation. It is
        only applied when ``annotations`` does not already contain that key.

    unroll_factor : int
        The unroll factor of the loop, stored under the
        ``"pragma_unroll_factor"`` annotation. Only allowed when
        ``"pragma_unroll_explicit"`` is falsy.

    annotations : Dict[str, Any]
        The optional annotations of the For statement. The mapping passed in
        is copied, never modified.

    Returns
    -------
    res : frame.ForFrame
        The ForFrame.

    Raises
    ------
    ValueError
        If ``unroll_factor`` is given while ``"pragma_unroll_explicit"`` is
        truthy.
    """
    if stop is None:
        # Single-argument form: unroll(n) iterates over [0, n).
        stop = start
        start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0

    # Work on a private copy so the caller's dict is never mutated, and fall
    # back to `explicit` only when the caller did not set the key themselves.
    annotations = {} if annotations is None else dict(annotations)
    annotations.setdefault("pragma_unroll_explicit", explicit)

    if unroll_factor is not None:
        # A factor is only accepted for non-explicit unrolling.
        # NOTE(review): presumably because explicit unrolling expands the loop
        # before codegen, leaving nothing to attach a factor to -- confirm.
        if annotations["pragma_unroll_explicit"]:
            raise ValueError("pragma_unroll_explicit must be False when unroll_factor is not None")
        annotations["pragma_unroll_factor"] = unroll_factor

    # Detect a unit step the same way `serial` does, so that step=1 takes the
    # plain unrolled-loop path instead of the stepped wrapper.
    unit_step = (
        step is None or (isinstance(step, int) and step == 1) or
        (isinstance(step, IntImm) and step.value == 1))
    if unit_step:
        return tb_tir.unroll(start, stop, annotations=annotations)
    return UnrollForWithStep(start, stop, step, annotations=annotations)
# Capitalized aliases for the lowercase loop constructors `serial` and `unroll`.
Serial = serial
Unroll = unroll
......@@ -112,6 +112,11 @@ class SerialForWithStep:
annotations: dict[str, Any] | None = None
@dataclass
class UnrollForWithStep(SerialForWithStep):
    """Stepped loop description that should be built as an unrolled loop.

    Adds no fields of its own; the subclass type alone distinguishes it from
    a plain ``SerialForWithStep``.
    """
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
# Use tuple for isinstance checks and typing.Union for annotations/aliases
ContinueOrBreak = (ContinueFrame, BreakFrame)
......@@ -270,7 +275,7 @@ class Builder(BaseBuilder):
def ctx_for(self, it):
self.check_continue_break()
it = unwrap_expr(it)
if isinstance(it, SerialForWithStep):
if isinstance(it, (SerialForWithStep, UnrollForWithStep)):
# Validate and compute the trip count before constructing the frame
if isinstance(it.step, (int, IntImm)):
step_value = it.step if isinstance(it.step, int) else it.step.value
......@@ -285,7 +290,14 @@ class Builder(BaseBuilder):
f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang'
)
real_stop = tir.ceildiv(it.stop - it.start, it.step)
real_frame = tir.serial(real_stop, annotations=it.annotations)
if isinstance(it, UnrollForWithStep):
real_frame = tir.unroll(real_stop, annotations=it.annotations)
elif isinstance(it, SerialForWithStep):
real_frame = tir.serial(real_stop, annotations=it.annotations)
else:
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding")
with self.with_frame(real_frame) as v:
IRBuilder.name('_tmp', v)
yield it.start + v * it.step
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment