"...resnet50_tensorflow.git" did not exist on "d1988e3e92ad5f93ed7279a051c3cf25e3b4f8d4"
Unverified commit 9e67b861, authored by Chaofan Lin, committed by GitHub

[Language][UX] Nested loop checker in pre-lowering stage (#1288)

* [Language][UX] Nested loop checker in pre-lowering stage

* rename

* comment

* address comments
parent 49f35393
@@ -93,7 +93,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
   }
   for (int i = 0; i < old_loop_depth; i++) {
     const ForNode *loop = body.as<ForNode>();
-    ICHECK(loop != nullptr);
+    ICHECK(loop != nullptr)
+        << "No extra statements are allowed between nested parallel loops.";
     vmap.Set(loop->loop_var, indices[i]);
     loop_mins.push_back(loop->min);
     loop_extents.push_back(loop->extent);
...
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import pytest
tilelang.testing.set_random_seed()
def _require_cuda_tensor(shape, dtype=torch.float32):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
try:
return torch.randn(*shape, device="cuda", dtype=dtype)
except RuntimeError as err:
pytest.skip(f"CUDA runtime unavailable: {err}")
"""
Nested Parallel cases:
    T.Parallel
        T.Parallel
Rule:
- Continuous (strictly nested) parallel loops are allowed and will be merged into one T.Parallel.
- Non-continuous nests (e.g. with extra statements in the outer loop) are forbidden.
"""
@tilelang.jit(out_idx=[1])
def nested_continuous_parallels(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
):
with T.Kernel(1, threads=length) as _:
for i in T.Parallel(length // block):
for j in T.Parallel(block):
B[i * block + j] = A[i * block + j] + 1.0
return main
@tilelang.jit(out_idx=[1])
def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
):
with T.Kernel(1, threads=length) as _:
for i in T.Parallel(length // block1 // block2):
for j in T.Parallel(block1):
for k in T.Parallel(block2):
B[i * block1 * block2 + j * block2 +
k] = A[i * block1 * block2 + j * block2 + k] + 1.0
return main
@tilelang.jit(out_idx=[1])
def nested_noncontinuous_parallels(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
):
with T.Kernel(1, threads=length) as _:
for i in T.Parallel(length // block):
B[i] = 0
for j in T.Parallel(block):
B[i * block + j] = A[i * block + j] + 1.0
return main
def test_nested_parallels():
kernel1 = nested_continuous_parallels(length=256, block=16)
kernel2 = nested_triple_continuous_parallels(length=256, block1=8, block2=2)
data = _require_cuda_tensor((256,), torch.float32)
result1 = kernel1(data)
result2 = kernel2(data)
torch.testing.assert_close(result1, data + 1.0, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(result2, data + 1.0, atol=1e-5, rtol=1e-5)
# This is invalid
with pytest.raises(ValueError):
nested_noncontinuous_parallels(length=256, block=16)
"""
Nested Pipelined cases:
    T.Pipelined
        T.Pipelined
is OK.
"""
def matmul_nested_pipelines(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
for _ in T.Pipelined(extra_pipeline_repeats):
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_nested_pipelines(
order,
stage,
extra_pipeline_repeats,
):
M = 1024
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32
trans_A = False
trans_B = False
in_dtype = "float16"
out_dtype = "float16"
dtypeAccum = "float32"
num_threads = 128
program = matmul_nested_pipelines(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_threads,
order,
stage,
extra_pipeline_repeats,
)
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = kernel.get_profiler()
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate float32
            # automatically; the -0x1000 offset on the int32 view adjusts the low mantissa
            # bits that tf32 discards.
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_nested_pipelines():
run_gemm_nested_pipelines(order=[0, 1, 2], stage=[0, 0, 1], extra_pipeline_repeats=3)
"""
Nested serial cases:
    T.serial
        T.serial
is OK.
"""
@tilelang.jit(out_idx=[1])
def nested_continuous_serials(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
):
with T.Kernel(1, threads=length) as _:
for i in T.serial(length // block):
for j in T.serial(block):
B[i * block + j] = A[i * block + j] + 1.0
return main
@tilelang.jit(out_idx=[1])
def nested_noncontinuous_serials(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
):
with T.Kernel(1, threads=length) as _:
for i in T.serial(length // block):
B[i] = 0
for j in T.serial(block):
B[i * block + j] = A[i * block + j] + 1.0
return main
def test_nested_serials():
kernel1 = nested_continuous_serials(length=256, block=16)
data = _require_cuda_tensor((256,), torch.float32)
result1 = kernel1(data)
torch.testing.assert_close(result1, data + 1.0, atol=1e-5, rtol=1e-5)
# This is valid
nested_noncontinuous_serials(length=256, block=16)
"""
Mixed serial and Parallel loops:
(S-P)
    T.serial
        T.Parallel
(P-S)
    T.Parallel
        T.serial
Rule:
- No Parallel - * - Parallel: a T.Parallel must not (transitively) contain another
  T.Parallel with any non-parallel statement or loop in between.
"""
@tilelang.jit(out_idx=[1])
def nested_continuous_sp(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
):
with T.Kernel(1, threads=length) as _:
for i in T.serial(length // block):
for j in T.Parallel(block):
B[i * block + j] = A[i * block + j] + 1.0
return main
@tilelang.jit(out_idx=[1])
def nested_continuous_ps(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
):
with T.Kernel(1, threads=length) as _:
for i in T.Parallel(length // block):
for j in T.serial(block):
B[i * block + j] = A[i * block + j] + 1.0
return main
@tilelang.jit(out_idx=[1])
def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
):
with T.Kernel(1, threads=length) as _:
for i in T.Parallel(length // block1 // block2):
for j in T.serial(block1):
for k in T.Parallel(block2):
B[i * block1 * block2 + j * block2 +
k] = A[i * block1 * block2 + j * block2 + k] + 1.0
return main
@tilelang.jit(out_idx=[1])
def nested_continuous_sps(length=256, block1=8, block2=2, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
):
with T.Kernel(1, threads=length) as _:
for i in T.serial(length // block1 // block2):
for j in T.Parallel(block1):
for k in T.serial(block2):
B[i * block1 * block2 + j * block2 +
k] = A[i * block1 * block2 + j * block2 + k] + 1.0
return main
def test_mixed_sp():
kernel1 = nested_continuous_sp(length=256, block=16)
kernel2 = nested_continuous_ps(length=256, block=16)
data = _require_cuda_tensor((256,), torch.float32)
result1 = kernel1(data)
result2 = kernel2(data)
torch.testing.assert_close(result1, data + 1.0, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(result2, data + 1.0, atol=1e-5, rtol=1e-5)
# This should be invalid (Undefined behaviour)
with pytest.raises(ValueError):
nested_continuous_psp(length=256, block1=16, block2=8)
kernel3 = nested_continuous_sps(length=256, block1=8, block2=2)
result3 = kernel3(data)
torch.testing.assert_close(result3, data + 1.0, atol=1e-5, rtol=1e-5)
"""
Mixed Pipelined and Parallel loops:
(Pi-Pa)
    T.Pipelined
        T.Parallel
(Pa-Pi)
    T.Parallel
        T.Pipelined
Rule:
- Pi-Pa is OK, whereas Pa-Pi is not allowed.
- For deeper nests, refer to the rules for T.Parallel.
"""
def matmul_nested_pipa(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
threads,
order,
stage,
):
A_shape = (M, K)
B_shape = (K, N)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
for i, j in T.Parallel(block_M, block_K):
A_shared[i, j] = A[by * block_M + i, k * block_K + j]
for i, j in T.Parallel(block_K, block_N):
B_shared[i, j] = B[k * block_K + i, bx * block_N + j]
# T.copy(A[by * block_M, k * block_K], A_shared)
# T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, False, False)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def matmul_nested_papipa(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
threads,
order,
stage,
):
A_shape = (M, K)
B_shape = (K, N)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for _ in T.Parallel(1):
for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
for i, j in T.Parallel(block_M, block_K):
A_shared[i, j] = A[by * block_M + i, k * block_K + j]
for i, j in T.Parallel(block_K, block_N):
B_shared[i, j] = B[k * block_K + i, bx * block_N + j]
# T.copy(A[by * block_M, k * block_K], A_shared)
# T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, False, False)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_mixed_pp(
order,
stage,
):
M = 1024
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32
in_dtype = "float16"
out_dtype = "float16"
dtypeAccum = "float32"
num_threads = 128
program = matmul_nested_pipa(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
dtypeAccum,
num_threads,
order,
stage,
)
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = kernel.get_profiler()
def ref_program(A, B):
import torch
if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate float32
            # automatically; the -0x1000 offset on the int32 view adjusts the low mantissa
            # bits that tf32 discards.
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
program1 = matmul_nested_papipa(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
dtypeAccum,
num_threads,
order,
stage,
)
with pytest.raises(ValueError):
tilelang.compile(
program1,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def test_mixed_pp():
run_gemm_mixed_pp(order=[0, 1, 2], stage=[0, 0, 1])
if __name__ == "__main__":
tilelang.testing.main()
@@ -133,6 +133,7 @@ from .layout import (
     Fragment,  # noqa: F401
 )
 from . import (
+    analysis,  # noqa: F401
     transform,  # noqa: F401
     language,  # noqa: F401
     engine,  # noqa: F401
...
"""Tilelang IR analysis & visitors."""
from .nested_loop_checker import NestedLoopChecker # noqa: F401
from tvm import tir
from tvm.tir import (
For,
PrimFunc,
PyStmtExprVisitor,
)
from tvm.tir.transform import prim_func_pass
def is_pipelined_for(op: For) -> bool:
"""Check if a for loop is pipelined."""
anno_keys = [
"num_stages", "tl_pipeline_order", "tl_pipeline_stage", "tl_pipeline_sync",
"tl_pipeline_group"
]
return any(key in op.annotations for key in anno_keys)
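# A minimal sketch of the check above, assuming the standard tvm.tir.For constructor
# (loop_var, min, extent, kind, body, thread_binding=None, annotations=None): a loop
# carrying e.g. the "num_stages" annotation counts as pipelined, a plain loop does not.
#
#     i = tir.Var("i", "int32")
#     body = tir.Evaluate(0)
#     plain = tir.For(i, 0, 8, tir.ForKind.SERIAL, body)
#     staged = tir.For(i, 0, 8, tir.ForKind.SERIAL, body, annotations={"num_stages": 2})
#     assert not is_pipelined_for(plain)
#     assert is_pipelined_for(staged)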
@tir.functor.visitor
class _NestedLoopCheckVisitor(PyStmtExprVisitor):
def __init__(self) -> None:
super().__init__()
self.in_parallel_context = False
def visit_for_(self, op: For) -> None:
if op.kind == tir.ForKind.PARALLEL:
child = op.body
# Special case: continuous nested parallel loop is allowed.
if isinstance(child, tir.For) and child.kind == tir.ForKind.PARALLEL:
self.visit_stmt(child)
return
# Otherwise
if self.in_parallel_context:
raise ValueError("Nested parallel loops are not allowed. "
"Please check your loop structure.")
self.in_parallel_context = True
self.visit_stmt(child)
self.in_parallel_context = False
return
        elif is_pipelined_for(op):
            if self.in_parallel_context:
                raise ValueError("Pipelined loop cannot be nested inside a parallel loop. "
                                 "Please check your loop structure.")
            self.visit_stmt(op.body)
        else:
            # Serial (and other non-parallel, non-pipelined) loops: keep visiting
            # the body so that violations deeper in the nest are still caught.
            self.visit_stmt(op.body)
def NestedLoopChecker():
"""
    User-friendly pass that identifies any invalid nested-loop pattern.

    Nested loops are a tricky problem in tilelang and other polyhedral-style compilers:
    they involve many corner cases and undefined behaviours.

    In tilelang there are four loop types:
        T.serial
        T.Parallel (T.vectorized)
        T.Pipelined
        T.Persistent
    T.Persistent is a new feature that we do not consider here.

    We define the following rules:
    - (Rule 1) T.serial can be nested inside any other loop type without restriction.
    - (Rule 2) Nesting T.Parallel inside T.Parallel is, in general, not allowed. This also
      covers any TileOp (T.copy, etc.) with implicit "parallel" behaviour appearing inside
      a T.Parallel loop.
      Examples:
          for i in T.Parallel(M):
              stmt
              for j in T.Parallel(N):  # forbidden: the nest is not contiguous
                  ...

          for i in T.Parallel(M):
              T.copy(A, B)  # forbidden!

      **Only one special case is allowed: strictly continuous Parallel loops**, since they
      can be fused into a single T.Parallel loop. Example:
          for i in T.Parallel(M):
              for j in T.Parallel(N):
                  ...  # allowed
    - (Rule 3) T.Pipelined inside a T.Parallel is forbidden.
      Examples:
          for i in T.Parallel(M):
              for j in T.Pipelined(K):  # forbidden!
                  ...

          for i in T.Pipelined(K):
              for j in T.Parallel(N):  # allowed
                  ...

    In summary, the pitfalls mainly concern T.Parallel. We strongly recommend using
    T.Parallel only to implement a tiled operator inside a kernel (e.g. at the T.gemm
    level) rather than for other purposes; following this guideline avoids most issues.

    Returns:
        A prim_func_pass that performs the validation (it does not transform the module).
"""
def pass_fn(func: PrimFunc, mod, ctx):
_NestedLoopCheckVisitor().visit_stmt(func.body)
return func
return prim_func_pass(pass_fn, opt_level=0)
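A minimal usage sketch (hedged: the kernel below is hypothetical, but the pass interface matches the code above); the checker is applied directly to an IRModule and raises a ValueError instead of returning a transformed module:

import tvm
import tilelang.language as T
from tilelang.analysis import NestedLoopChecker

@T.prim_func
def bad_kernel(A: T.Tensor((256,), "float32"), B: T.Tensor((256,), "float32")):
    with T.Kernel(1, threads=256) as _:
        for i in T.Parallel(16):
            B[i] = 0.0  # extra statement makes the parallel nest non-contiguous
            for j in T.Parallel(16):
                B[i * 16 + j] = A[i * 16 + j] + 1.0

mod = tvm.IRModule({"main": bad_kernel})
try:
    NestedLoopChecker()(mod)  # validation only; the module is not rewritten
except ValueError as err:
    print("rejected:", err)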
@@ -16,6 +16,7 @@ from tilelang.utils.deprecated import deprecated_warning
 from tilelang.engine.param import KernelParam, CompiledArtifact
 from tilelang.utils.target import determine_target
 from tilelang.engine.phase import (
+    PreLowerSemanticCheck,
     LowerAndLegalize,
     OptimizeForTarget,
 )
...
@@ -242,6 +243,9 @@ def lower(
     _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target))
     _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target))
 
+    # Before lowering, do semantic check
+    PreLowerSemanticCheck(mod)
+
     # Phase 1: Lower and legalize the IR
     mod = LowerAndLegalize(mod, target)
...
@@ -67,6 +67,17 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool:
     return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False))
 
 
+def PreLowerSemanticCheck(mod: IRModule) -> None:
+    """
+    Check whether the module is valid before lowering. If it is not, raise a user-friendly
+    error on the Python side instead of letting it surface from deep inside the TVM/C++ stack.
+    Note: this is a validation-only pipeline of passes; it does not modify or return the module.
+    """
+    # Check if there are any invalid nested loops.
+    tilelang.analysis.NestedLoopChecker()(mod)
+
+
 def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
     # Bind the target device information to the module
     """
...