Commit 47caf219 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Bugfix] Support `T.clear` for let binding (#268)

* Fix indentation in JIT adapter wrapper to ensure consistent formatting of return statement in generated C code.

* Enhance Fill Operation in TileLang

- Updated the Fill constructor to support BufferLoad instances, adding checks for ramp indices and ensuring only stride 1 ramps are processed.
- Introduced a region array to manage the bounds of the fill operation, improving error checking for static regions.
- Modified the MakeSIMTLoop method to utilize the new region array for loop variable bounds, enhancing flexibility in kernel generation.
- Updated the fill and clear functions in fill.py to accept both tir.Buffer and tir.BufferRegion types, improving usability and type handling.

* Refactor Fill Operation and Improve Readability

- Simplified the Fill constructor by enhancing the handling of BufferLoad instances and ensuring proper checks for ramp indices.
- Improved error messages for region size checks to enhance clarity.
- Cleaned up formatting in the Fill method for better readability.
- Added a blank line in the matmul function test to improve code organization.
- Introduced a blank line in the fill function to enhance readability in fill.py.

* Add matrix multiplication functionality and test in TileLang

- Introduced a new test file `test_tilelang_language_clear.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types.
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated the `__init__.py` in the utils module to include `map_torch_type`, enhancing type handling for tensor operations.

* lint fix
parent 9981ac59
...@@ -355,12 +355,49 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -355,12 +355,49 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
} }
// Construct a Fill operator from the TL intrinsic call arguments.
//
// args[0] identifies the destination. It is either:
//   * a BufferLoad — produced when `T.clear`/`T.fill` is applied to a
//     let-bound buffer region (e.g. `X_local = C_local[:m, :k]`), whose
//     indices encode the region to fill; or
//   * an access pointer to a whole buffer, in which case the region covers
//     every dimension entirely.
// args[1] is the fill value; it is cast to the destination dtype if needed.
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
  if (args[0]->IsInstance<BufferLoadNode>()) {
    // Convert the BufferLoad indices into a region: a ramp index becomes the
    // range [base, base + lanes), a scalar index becomes a unit-extent range.
    auto buffer_load = Downcast<BufferLoad>(args[0]);
    for (const auto &index : buffer_load->indices) {
      if (const auto *ramp = index.as<RampNode>()) {
        // Guard the cast before dereferencing: a symbolic stride would
        // otherwise be a null-pointer dereference.
        const auto *stride = ramp->stride.as<IntImmNode>();
        CHECK(stride && stride->value == 1)
            << "Only stride 1 ramps are supported";
        const auto *lanes = ramp->lanes.as<IntImmNode>();
        CHECK(lanes)
            << "Scalable vectors not supported in BufferRegion conversion";
        region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
      } else {
        region.push_back(Range::FromMinExtent(index, 1));
      }
    }
    dst = buffer_load->buffer;
  } else {
    // Whole-buffer fill: region spans the full shape.
    dst = vmap[GetVarFromAccessPtr(args[0])];
    for (size_t i = 0; i < dst->shape.size(); i++) {
      region.push_back(Range(0, dst->shape[i]));
    }
  }
  if (args[1]->dtype != dst->dtype) {
    value = Cast(dst->dtype, args[1]);
  } else {
    value = args[1];
  }
  ICHECK(region.size() == dst->shape.size())
      << "region size = " << region.size() << " != " << dst->shape.size();
  for (size_t i = 0; i < region.size(); i++) {
    // Bound-check only the statically known parts of the region; symbolic
    // mins/extents (and dynamic shape dims) are left to later checks.
    if (region[i]->min.as<IntImm>()) {
      int64_t min = Downcast<IntImm>(region[i]->min)->value;
      ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
    }
    if (region[i]->extent.as<IntImm>() && dst->shape[i].as<IntImm>()) {
      int64_t extent = Downcast<IntImm>(region[i]->extent)->value;
      ICHECK_LE(extent, Downcast<IntImm>(dst->shape[i])->value)
          << "region[" << i << "] = " << extent << " > " << dst->shape[i];
    }
  }
}
For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const { For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
...@@ -369,7 +406,7 @@ For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -369,7 +406,7 @@ For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<PrimExpr> dst_indices; Array<PrimExpr> dst_indices;
for (int i = 0; i < ndim; i++) { for (int i = 0; i < ndim; i++) {
Var var = Var(std::string{char('i' + i)}); Var var = Var(std::string{char('i' + i)});
loop_vars.push_back({Range(0, dst->shape[i]), var, IterVarType::kDataPar}); loop_vars.push_back({region[i], var, IterVarType::kDataPar});
dst_indices.push_back(var); dst_indices.push_back(var);
} }
Stmt body = BufferStore(dst, value, dst_indices); Stmt body = BufferStore(dst, value, dst_indices);
......
...@@ -56,6 +56,7 @@ private: ...@@ -56,6 +56,7 @@ private:
For MakeSIMTLoop(arith::Analyzer *analyzer) const; For MakeSIMTLoop(arith::Analyzer *analyzer) const;
tir::Buffer dst; tir::Buffer dst;
PrimExpr value; PrimExpr value;
Array<Range> region;
}; };
} // namespace tl } // namespace tl
......
...@@ -19,6 +19,8 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ...@@ -19,6 +19,8 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
T.clear(C_local) T.clear(C_local)
X_shared = A_shared[:block_M, :block_K] X_shared = A_shared[:block_M, :block_K]
X_local = C_local[:block_M, :block_K]
T.clear(X_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
# Copy tile of A # Copy tile of A
......
import tilelang
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build a TileLang prim_func computing C = A @ B^T with tiled GEMM.

    Note: the kernel deliberately calls ``T.clear(A_shared)`` after the copy,
    so the produced C is expected to be all zeros — this program exists to
    exercise ``T.clear``, not to compute a real matmul.

    Args:
        M, N, K: Global GEMM dimensions.
        block_M, block_N, block_K: Tile sizes per thread block.
        dtype: Element type of A, B, and C.
        accum_dtype: Accumulation type for the local fragment.

    Returns:
        The TIR prim_func ``main(A, B, C)``.
    """
    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Initialize Kernel Context: one block per (N, M) tile, 128 threads.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            # Zero the accumulator fragment before the K loop.
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Copy tile of A
                # This is a sugar syntax for parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)
                # Immediately zero the A tile — makes every GEMM contribution
                # zero, which the caller's all-zeros assertion relies on.
                T.clear(A_shared)
                # Demonstrate parallelized copy from global to shared for B
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Compile the matmul program for CUDA, run it on random inputs, and
    assert the output is all zeros (the kernel clears A_shared, so every
    GEMM contribution vanishes)."""
    import torch
    from tilelang.utils import map_torch_type

    prim = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
    compiled = tilelang.compile(
        prim, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True})

    torch_dtype = map_torch_type(dtype)
    lhs = torch.randn((M, K), dtype=torch_dtype).cuda()
    rhs = torch.randn((N, K), dtype=torch_dtype).cuda()
    out = compiled(lhs, rhs)

    assert torch.allclose(out, torch.zeros_like(out))
def test_matmul():
    # Smoke test for T.clear on a let-bound region: the kernel zeroes
    # A_shared, so run_matmul's all-zeros assertion must hold.
    run_matmul(1024, 1024, 1024, 128, 128, 32)


if __name__ == "__main__":
    test_matmul()
...@@ -9,9 +9,9 @@ import logging ...@@ -9,9 +9,9 @@ import logging
import textwrap import textwrap
PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """ PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """
cudaError_t result = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1}); cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
if (result != CUDA_SUCCESS) {{ if (result_{0} != CUDA_SUCCESS) {{
snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", {1}, cudaGetErrorString(result)); snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", {1}, cudaGetErrorString(result_{0}));
return -1; return -1;
}} }}
""" """
...@@ -34,7 +34,7 @@ extern "C" int init() {{ ...@@ -34,7 +34,7 @@ extern "C" int init() {{
PREDEF_HOST_FUNC = """ PREDEF_HOST_FUNC = """
extern "C" int call({}) {{ extern "C" int call({}) {{
{} {}
return 0; return 0;
}} }}
""" """
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from tvm import tir from tvm import tir
from typing import Union
from tilelang.language import has_let_value, get_let_value
def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr):
    """Fill a buffer or buffer region with a specified value.

    Args:
        buffer: Either a TVM buffer or buffer region to be filled
        value: The value to fill the buffer with

    Returns:
        A TVM intrinsic call that performs the fill operation
    """
    # A plain Buffer is lowered to its write access pointer; a BufferRegion
    # is forwarded as-is so the intrinsic can recover the sub-region.
    target = buffer.access_ptr("w") if isinstance(buffer, tir.Buffer) else buffer
    return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), target, value)
def clear(buffer: Union[tir.Buffer, tir.Var]):
    """Clear a buffer by filling it with zeros.

    Args:
        buffer: Either a TVM buffer or a variable that contains a buffer region

    Returns:
        A fill operation that sets the buffer contents to zero

    Raises:
        ValueError: If the buffer variable contains an invalid buffer region
    """
    # Guard clause: anything that is not a let-bound variable is filled directly.
    if not (isinstance(buffer, tir.Var) and has_let_value(buffer)):
        return fill(buffer, 0)
    # Resolve the let binding to the underlying buffer region.
    buffer_region = get_let_value(buffer)
    if not isinstance(buffer_region, tir.BufferRegion):
        raise ValueError(f"Invalid buffer region: {buffer_region}")
    return fill(buffer_region, 0)
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from .target import determine_target # noqa: F401 from .target import determine_target # noqa: F401
from .tensor import TensorSupplyType, torch_assert_close # noqa: F401 from .tensor import TensorSupplyType, torch_assert_close, map_torch_type # noqa: F401
from .language import ( from .language import (
is_global, # noqa: F401 is_global, # noqa: F401
is_shared, # noqa: F401 is_shared, # noqa: F401
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment