Commit 388ee7ee (unverified), authored by Lei Wang, committed by GitHub
Browse files

[Enhancement] Implement dynamic unroll factor in CUDA code generation (#1360)

* [Enhancement] Implement dynamic unroll factor in CUDA code generation

This commit introduces support for specifying a dynamic unroll factor in the CUDA code generation. The `unroll_factor` map is added to store unroll factors for loop variables, allowing for more flexible and optimized loop unrolling. Additionally, the `unroll` function is integrated into the loop language, enabling users to define unroll factors directly in their code. This enhancement improves performance by allowing tailored unrolling strategies based on specific loop characteristics.

* lint fix

* [Bugfix] Correct initialization of non-zero counters in custom compress kernel and update TIR registration for gemm_sp_py to use the correct tile operation
parent e547d247
...@@ -275,8 +275,9 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): ...@@ -275,8 +275,9 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
for tm in T.Parallel(block_M): for tm in T.Parallel(block_M):
for g_i in range(0, block_K // group): for g_i in range(0, block_K // group):
a_k = g_i * group a_k = g_i * group
T.clear(non_zero_cnt) non_zero_cnt[0] = 0
T.clear(non_zero_elt_log_idx) for i in range(elem):
non_zero_elt_log_idx[i] = 0
for i in range(group): for i in range(group):
val = A_shared[tm, a_k + i] val = A_shared[tm, a_k + i]
if val != 0.0: if val != 0.0:
......
...@@ -312,7 +312,12 @@ std::string CodeGenTileLangCUDA::Finish() { ...@@ -312,7 +312,12 @@ std::string CodeGenTileLangCUDA::Finish() {
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) { void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
if (op->kind == tir::ForKind::kUnrolled) { if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent(); PrintIndent();
stream << "#pragma unroll\n"; if (unroll_factor.count(op->loop_var.get())) {
stream << "#pragma unroll "
<< PrintExpr(unroll_factor[op->loop_var.get()]) << "\n";
} else {
stream << "#pragma unroll\n";
}
} }
std::string extent = std::string extent =
PrintExpr(arith::Analyzer().Simplify(op->extent + op->min)); PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
...@@ -2661,7 +2666,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) { ...@@ -2661,7 +2666,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
this->stream << "const dim3 blockIdx = " << pattern->value << "();\n"; this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
this->VisitStmt(op->body); this->VisitStmt(op->body);
return; return;
} else if (op->attr_key == "pragma_unroll_factor") {
const IntImmNode *factor = op->value.as<IntImmNode>();
ICHECK(factor);
unroll_factor[op->node.as<VarNode>()] = Downcast<IntImm>(factor);
} }
CodeGenC::VisitStmt_(op); CodeGenC::VisitStmt_(op);
} }
......
...@@ -140,6 +140,7 @@ private: ...@@ -140,6 +140,7 @@ private:
std::unordered_map<const VarNode *, std::string> fragment_shapes; std::unordered_map<const VarNode *, std::string> fragment_shapes;
std::unordered_map<const VarNode *, std::string> fragment_layouts; std::unordered_map<const VarNode *, std::string> fragment_layouts;
std::unordered_map<const VarNode *, IntImm> unroll_factor;
friend void PrintConst(const FloatImmNode *op, std::ostream &os, friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p); CodeGenTileLangCUDA *p);
void PrintWmmaScope(const std::string &scope, DataType t, void PrintWmmaScope(const std::string &scope, DataType t,
......
import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T
def test_unroll_with_step():
    """A stepped ``T.unroll`` loop must still emit ``#pragma unroll`` in the CUDA source."""

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                for i in T.unroll(0, 16, step=4):
                    A[0, i] = 1.0

    # Compile for CUDA and inspect the generated kernel text.
    compiled = tilelang.compile(main, target="cuda")
    generated_source = compiled.get_kernel_source()
    assert "#pragma unroll" in generated_source
def test_unroll_with_unroll_factor():
    """``T.unroll(..., unroll_factor=4)`` must surface as ``#pragma unroll 4`` in the CUDA source."""

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                for i in T.unroll(0, 16, unroll_factor=4):
                    A[0, i] = 1.0

    # Compile for CUDA and inspect the generated kernel text.
    compiled = tilelang.compile(main, target="cuda")
    generated_source = compiled.get_kernel_source()
    assert "#pragma unroll 4" in generated_source
# When executed as a script, hand control to the tilelang test entry point.
if __name__ == "__main__":
    tilelang.testing.main()
...@@ -24,7 +24,15 @@ from .proxy import ( ...@@ -24,7 +24,15 @@ from .proxy import (
LocalBuffer, # noqa: F401 LocalBuffer, # noqa: F401
Ref, # noqa: F401 Ref, # noqa: F401
) )
from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401 from .loop import (
Parallel, # noqa: F401
Persistent, # noqa: F401
Pipelined, # noqa: F401
serial, # noqa: F401
unroll, # noqa: F401
Serial, # noqa: F401
Unroll, # noqa: F401
)
from .frame import has_let_value, get_let_value # noqa: F401 from .frame import has_let_value, get_let_value # noqa: F401
from .math_intrinsics import * # noqa: F401 from .math_intrinsics import * # noqa: F401
from .kernel import ( from .kernel import (
......
...@@ -198,7 +198,7 @@ def gemm_sp_v2( ...@@ -198,7 +198,7 @@ def gemm_sp_v2(
C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
return tir.call_intrin( return tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.gemm_sp_py"), tir.op.Op.get("tl.tileop.gemm_sp_py"),
A_arg, A_arg,
E_arg, E_arg,
B_arg, B_arg,
......
...@@ -4,8 +4,9 @@ from typing import Any ...@@ -4,8 +4,9 @@ from typing import Any
from tvm import tir from tvm import tir
from tvm.tir import IntImm from tvm.tir import IntImm
import tvm.script.ir_builder.tir as tb_tir import tvm.script.ir_builder.tir as tb_tir
from .v2.builder import SerialForWithStep from .v2.builder import SerialForWithStep, UnrollForWithStep
from tilelang import _ffi_api from tilelang import _ffi_api
from tvm.script.ir_builder.tir import frame
def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None): def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
...@@ -97,7 +98,7 @@ def serial(start: tir.PrimExpr, ...@@ -97,7 +98,7 @@ def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None, stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None, step: tir.PrimExpr | None = None,
*, *,
annotations: dict[str, Any] | None = None): annotations: dict[str, Any] | None = None) -> frame.ForFrame:
step_is_one = False step_is_one = False
step_is_one |= isinstance(step, int) and step == 1 step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1 step_is_one |= isinstance(step, IntImm) and step.value == 1
...@@ -108,3 +109,70 @@ def serial(start: tir.PrimExpr, ...@@ -108,3 +109,70 @@ def serial(start: tir.PrimExpr,
stop = start stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0 start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
return SerialForWithStep(start, stop, step, annotations=annotations) return SerialForWithStep(start, stop, step, annotations=annotations)
def unroll(start: tir.PrimExpr,
           stop: tir.PrimExpr | None = None,
           step: tir.PrimExpr | None = None,
           *,
           explicit: bool = False,
           unroll_factor: int | None = None,
           annotations: dict[str, Any] | None = None) -> frame.ForFrame:
    """The unrolled For statement.

    Parameters
    ----------
    start : PrimExpr
        The minimum value of iteration. When ``stop`` is omitted, this is
        treated as the upper bound and iteration starts at 0.

    stop : PrimExpr
        The maximum value of iteration.

    step : PrimExpr
        The step size of the iteration. ``None`` or a constant 1 produces a
        plain unrolled loop; any other value produces a stepped unrolled loop.

    explicit : bool
        Default for the ``pragma_unroll_explicit`` annotation. It is only
        applied when ``annotations`` does not already carry that key.

    unroll_factor : int
        The unroll factor of the loop, stored as the ``pragma_unroll_factor``
        annotation. Only allowed when ``pragma_unroll_explicit`` is falsy.

    annotations : Dict[str, Any]
        The optional annotations of the For statement. The mapping is copied,
        so the caller's dict is never mutated.

    Returns
    -------
    res : frame.ForFrame
        The ForFrame (or an ``UnrollForWithStep`` descriptor when a
        non-unit ``step`` is given).

    Raises
    ------
    ValueError
        If ``unroll_factor`` is given while ``pragma_unroll_explicit`` is truthy.
    """
    # Single-argument form: unroll(n) iterates over [0, n).
    if stop is None:
        stop = start
        start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0

    # Work on a private copy and default "pragma_unroll_explicit" to `explicit`
    # unless the caller already specified it in `annotations`.
    annotations = {} if annotations is None else dict(annotations)
    annotations.setdefault("pragma_unroll_explicit", explicit)

    if unroll_factor is not None:
        # A factor and explicit unrolling are mutually exclusive.
        if annotations.get("pragma_unroll_explicit", True):
            raise ValueError("pragma_unroll_explicit must be False when unroll_factor is not None")
        annotations["pragma_unroll_factor"] = unroll_factor

    # A constant unit step needs no stepped lowering; detect it the same way
    # `serial` does so that unroll(a, b, 1) behaves like unroll(a, b).
    step_is_one = False
    step_is_one |= isinstance(step, int) and step == 1
    step_is_one |= isinstance(step, IntImm) and step.value == 1

    if step is None or step_is_one:
        return tb_tir.unroll(start, stop, annotations=annotations)
    return UnrollForWithStep(start, stop, step, annotations=annotations)
# Capitalized aliases bound to the same callables as `serial` and `unroll`.
Serial = serial
Unroll = unroll
...@@ -112,6 +112,11 @@ class SerialForWithStep: ...@@ -112,6 +112,11 @@ class SerialForWithStep:
annotations: dict[str, Any] | None = None annotations: dict[str, Any] | None = None
@dataclass
class UnrollForWithStep(SerialForWithStep):
    """Stepped-loop descriptor that requests an unrolled loop.

    Declares no fields of its own beyond those inherited from
    SerialForWithStep; the distinct type lets callers tell the two apart
    with isinstance.
    """
    ...
# Python 3.9 compatibility: avoid PEP 604 unions at runtime # Python 3.9 compatibility: avoid PEP 604 unions at runtime
# Use tuple for isinstance checks and typing.Union for annotations/aliases # Use tuple for isinstance checks and typing.Union for annotations/aliases
ContinueOrBreak = (ContinueFrame, BreakFrame) ContinueOrBreak = (ContinueFrame, BreakFrame)
...@@ -270,7 +275,7 @@ class Builder(BaseBuilder): ...@@ -270,7 +275,7 @@ class Builder(BaseBuilder):
def ctx_for(self, it): def ctx_for(self, it):
self.check_continue_break() self.check_continue_break()
it = unwrap_expr(it) it = unwrap_expr(it)
if isinstance(it, SerialForWithStep): if isinstance(it, (SerialForWithStep, UnrollForWithStep)):
# Validate and compute the trip count before constructing the frame # Validate and compute the trip count before constructing the frame
if isinstance(it.step, (int, IntImm)): if isinstance(it.step, (int, IntImm)):
step_value = it.step if isinstance(it.step, int) else it.step.value step_value = it.step if isinstance(it.step, int) else it.step.value
...@@ -285,7 +290,14 @@ class Builder(BaseBuilder): ...@@ -285,7 +290,14 @@ class Builder(BaseBuilder):
f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang' f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang'
) )
real_stop = tir.ceildiv(it.stop - it.start, it.step) real_stop = tir.ceildiv(it.stop - it.start, it.step)
real_frame = tir.serial(real_stop, annotations=it.annotations) if isinstance(it, UnrollForWithStep):
real_frame = tir.unroll(real_stop, annotations=it.annotations)
elif isinstance(it, SerialForWithStep):
real_frame = tir.serial(real_stop, annotations=it.annotations)
else:
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding")
with self.with_frame(real_frame) as v: with self.with_frame(real_frame) as v:
IRBuilder.name('_tmp', v) IRBuilder.name('_tmp', v)
yield it.start + v * it.step yield it.start + v * it.step
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment