Unverified Commit 1b308baf authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Language] Introduce `StridedTensor` to support non-contiguous torch inputs (#722)



* Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107

* Support strided tensors

* Refactor target attribute helper functions for improved clarity

* No code changes made in proxy.py and setup.py

* lint fix

* lint fix via gemini

* lint fix

* test fix

* test fix

* lint fix

* Update wrapper.py

* test fix

* Enhance test for InjectSoftwarePipeline by adding LowerOpaqueBlock transformation and updating expected function signature to use match_buffer for better clarity.

* lint fix

---------
Co-authored-by: default avatarChenggang Zhao <chenggangz@deepseek.com>
parent c369d690
...@@ -7,8 +7,6 @@ import tilelang.language as T ...@@ -7,8 +7,6 @@ import tilelang.language as T
from tilelang.autotuner import * from tilelang.autotuner import *
from example_fusedmoe_torch import * from example_fusedmoe_torch import *
# tilelang.disable_cache()
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_shared(d_hidden, def moe_forward_tilelang_shared(d_hidden,
......
...@@ -145,20 +145,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -145,20 +145,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
clear_accum=True, clear_accum=True,
wg_wait=-1) wg_wait=-1)
T.barrier_wait(kv_shared_0_r_is_ready, k % 2) T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
T.gemm( T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1)
Q_shared_r,
KV_shared_0_r,
acc_s_0,
transpose_B=True,
wg_wait=-1)
T.barrier_wait(kv_shared_0_pe_is_ready, k % 2) T.barrier_wait(kv_shared_0_pe_is_ready, k % 2)
T.gemm( T.gemm(Q_pe_local_0, K_pe_shared_0, acc_s_0, transpose_B=True, wg_wait=-1)
Q_pe_local_0,
K_pe_shared_0,
acc_s_0,
transpose_B=True,
wg_wait=-1)
T.wait_wgmma(0) T.wait_wgmma(0)
...@@ -261,20 +251,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -261,20 +251,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
wg_wait=-1) wg_wait=-1)
T.barrier_wait(kv_shared_1_r_is_ready, k % 2) T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
T.gemm( T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1)
Q_shared_r,
KV_shared_1_r,
acc_s_1,
transpose_B=True,
wg_wait=-1)
T.barrier_wait(kv_shared_1_pe_is_ready, k % 2) T.barrier_wait(kv_shared_1_pe_is_ready, k % 2)
T.gemm( T.gemm(Q_pe_local_1, K_pe_shared_1, acc_s_1, transpose_B=True, wg_wait=-1)
Q_pe_local_1,
K_pe_shared_1,
acc_s_1,
transpose_B=True,
wg_wait=-1)
T.wait_wgmma(0) T.wait_wgmma(0)
...@@ -308,11 +288,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -308,11 +288,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
# Step 10. compute O1 with KV_shared_1_rd # Step 10. compute O1 with KV_shared_1_rd
T.copy(acc_s_1, acc_s_1_cast) T.copy(acc_s_1, acc_s_1_cast)
T.gemm( T.gemm(acc_s_1_cast, KV_shared_1_r, acc_o_r, wg_wait=-1)
acc_s_1_cast,
KV_shared_1_r,
acc_o_r,
wg_wait=-1)
T.copy(acc_s_1_cast, SP1_shared) T.copy(acc_s_1_cast, SP1_shared)
T.barrier_arrive(s_shared_ready_barrier) T.barrier_arrive(s_shared_ready_barrier)
......
import fcntl
import functools
import hashlib
import io import io
import subprocess import subprocess
import shutil import shutil
...@@ -12,9 +15,7 @@ from pathlib import Path ...@@ -12,9 +15,7 @@ from pathlib import Path
import os import os
import sys import sys
import site import site
import hashlib
import sysconfig import sysconfig
import functools
import urllib.request import urllib.request
from packaging.version import Version from packaging.version import Version
import platform import platform
...@@ -22,7 +23,6 @@ import multiprocessing ...@@ -22,7 +23,6 @@ import multiprocessing
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
import importlib import importlib
import logging import logging
import fcntl
# Configure logging with basic settings # Configure logging with basic settings
logging.basicConfig( logging.basicConfig(
...@@ -692,15 +692,15 @@ class TilelangExtensionBuild(build_ext): ...@@ -692,15 +692,15 @@ class TilelangExtensionBuild(build_ext):
with open(md5_path, "r") as f: with open(md5_path, "r") as f:
cached_hash = f.read().strip() cached_hash = f.read().strip()
if cached_hash == code_hash: if cached_hash == code_hash:
logger.info("Cython jit adapter is up to date, no need to compile...") logger.info("Cython JIT adapter is up to date, no need to compile...")
need_compile = False need_compile = False
else: else:
logger.info("Cython jit adapter is out of date, need to recompile...") logger.info("Cython JIT adapter is out of date, need to recompile...")
else: else:
logger.info("No cached version found for cython jit adapter, need to compile...") logger.info("No cached version found for Cython JIT adapter, need to compile...")
if need_compile: if need_compile:
logger.info("Waiting for lock to compile cython jit adapter...") logger.info("Waiting for lock to compile Cython JIT adapter...")
with open(lock_file, 'w') as lock: with open(lock_file, 'w') as lock:
fcntl.flock(lock.fileno(), fcntl.LOCK_EX) fcntl.flock(lock.fileno(), fcntl.LOCK_EX)
try: try:
...@@ -715,7 +715,7 @@ class TilelangExtensionBuild(build_ext): ...@@ -715,7 +715,7 @@ class TilelangExtensionBuild(build_ext):
need_compile = False need_compile = False
if need_compile: if need_compile:
logger.info("Compiling cython jit adapter...") logger.info("Compiling Cython JIT adapter...")
temp_path = cache_dir / f"temp_{code_hash}.so" temp_path = cache_dir / f"temp_{code_hash}.so"
with open(md5_path, "w") as f: with open(md5_path, "w") as f:
...@@ -736,7 +736,7 @@ class TilelangExtensionBuild(build_ext): ...@@ -736,7 +736,7 @@ class TilelangExtensionBuild(build_ext):
except Exception as e: except Exception as e:
if 'temp_path' in locals() and temp_path.exists(): if 'temp_path' in locals() and temp_path.exists():
temp_path.unlink() temp_path.unlink()
raise Exception(f"Failed to compile cython jit adapter: {e}") from e raise Exception(f"Failed to compile Cython JIT adapter: {e}") from e
finally: finally:
if lock_file.exists(): if lock_file.exists():
lock_file.unlink() lock_file.unlink()
......
...@@ -1702,6 +1702,76 @@ void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) { ...@@ -1702,6 +1702,76 @@ void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
os << "))"; os << "))";
} }
void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
                                     std::ostream &os) { // NOLINT(*)
  // Emit CUDA source for a buffer load. Three cases are handled:
  //   1. value lanes == element lanes  -> plain (possibly volatile) reference
  //   2. contiguous ramp index         -> single vector load
  //   3. otherwise                     -> per-lane loads composed into a
  //                                       vector expression
  ICHECK_EQ(op->indices.size(), 1)
      << "Load from non-flat memory not supported.";
  ICHECK(!op->predicate.defined())
      << "Predicated buffer load is not supported.";
  DataType value_dtype = op->dtype;
  PrimExpr index = op->indices[0];
  Var buffer_var = op->buffer->data;
  DataType element_dtype = op->buffer->dtype;
  int lanes = op->dtype.lanes();
  // declare type.
  if (value_dtype.lanes() == element_dtype.lanes()) {
    // Scalar load (or a load whose vector width already matches the buffer's
    // element type): emit a direct reference.
    std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index);
    HandleVolatileLoads(ref, op, os);
  } else {
    bool can_vector_load = false;
    arith::PVar<PrimExpr> base;
    // A Ramp(base, stride=1, lanes) index means the lanes are contiguous in
    // memory, so a single vector load is possible.
    if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
      const RampNode *ramp = index.as<RampNode>();
      ICHECK(ramp);
      can_vector_load = true;
      // NOTE(review): the alignment check below is disabled — vector loads
      // are emitted for any unit-stride ramp regardless of base alignment.
      // arith::ModularSet me = arith::Analyzer().modular_set(ramp->base);
      // The condition: {k * coeff + base} divisible by the alignment for any k
      // if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes()
      // == 0) {
      //   can_vector_load = true;
      // }
    }
    if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
      // A float4_e2m1fn element has 4 bits, which is an incomplete byte.
      // So we cannot vector load it.
      can_vector_load = false;
    }
    if (can_vector_load) {
      std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
      HandleVolatileLoads(ref, op, os);
    } else {
      // Fallback: load each lane individually and stitch the results into a
      // vector expression. The index is SSA-bound once so side effects (and
      // codegen cost) are not repeated per lane.
      std::ostringstream svalue_expr;
      std::string sindex = SSAGetID(PrintExpr(index), index.dtype());
      std::string vid = GetVarID(buffer_var.get());
      DataType elem_type = op->dtype.element_of();
      for (int i = 0; i < lanes; ++i) {
        std::ostringstream value_temp;
        if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
          // The buffer variable's declared type differs from the element type
          // being loaded: cast the pointer, preserving its storage scope.
          value_temp << "((";
          if (buffer_var.get()->dtype.is_handle()) {
            auto it = alloc_storage_scope_.find(buffer_var.get());
            if (it != alloc_storage_scope_.end()) {
              PrintStorageScope(it->second, value_temp);
            }
          }
          PrintType(elem_type, value_temp);
          value_temp << "*)" << vid << ')';
        } else {
          value_temp << vid;
        }
        value_temp << '[';
        PrintVecElemLoad(sindex, index.dtype(), i, value_temp);
        value_temp << ']';
        PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
      }
      os << svalue_expr.str();
    }
  }
}
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op, void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*) std::ostream &os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value); int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
......
...@@ -50,6 +50,7 @@ public: ...@@ -50,6 +50,7 @@ public:
void VisitStmt_(const EvaluateNode *op) final; void VisitStmt_(const EvaluateNode *op) final;
void VisitStmt_(const AllocateNode *op) final; void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final; void VisitStmt_(const AttrStmtNode *op) final;
void VisitExpr_(const BufferLoadNode *op, std::ostream &os) final;
// Override this as a work around for __grid_constant__ parameter // Override this as a work around for __grid_constant__ parameter
void AddFunction(const GlobalVar &gvar, const PrimFunc &f); void AddFunction(const GlobalVar &gvar, const PrimFunc &f);
......
...@@ -22,7 +22,8 @@ struct MinOp { ...@@ -22,7 +22,8 @@ struct MinOp {
} }
}; };
template <class Reducer, int threads, int scale, int thread_offset = 0> struct AllReduce { template <class Reducer, int threads, int scale, int thread_offset = 0>
struct AllReduce {
static_assert(threads == 1024 || threads == 512 || threads == 256 || static_assert(threads == 1024 || threads == 512 || threads == 256 ||
threads == 128 || threads == 64 || threads == 32 || threads == 128 || threads == 64 || threads == 32 ||
threads == 16 || threads == 8 || threads == 4 || threads == 2); threads == 16 || threads == 8 || threads == 4 || threads == 2);
......
...@@ -136,11 +136,23 @@ private: ...@@ -136,11 +136,23 @@ private:
max_vector_size = gcd_base; max_vector_size = gcd_base;
} }
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_); vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
// Generate strides if not existed
auto strides = buffer->strides;
if (buffer->strides.size() == 0) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
}
// Generate and check element offset expression
ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
PrimExpr elem_offset = 0; PrimExpr elem_offset = 0;
PrimExpr stride = 1; for (int i = 0; i < indices.size(); ++i) {
for (int i = indices.size() - 1; i >= 0; --i) { elem_offset += indices[i] * strides[i];
elem_offset = elem_offset + indices[i] * stride;
stride = stride * buffer->shape[i];
} }
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_, inner_for_->extent, vector_size_,
...@@ -229,10 +241,19 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, ...@@ -229,10 +241,19 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
ICHECK(target_vectorized_size >= 1); ICHECK(target_vectorized_size >= 1);
if (target_vectorized_size == 1) if (target_vectorized_size == 1)
return true; return true;
// bind thread range
// Extent must be divisible
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size), if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
0)) 0))
return false; return false;
// The base offset must be divisible
if (!analyzer->CanProveEqual(
FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) {
return false;
}
// Bind thread range
Var v0("v0"), v1("v1"); Var v0("v0"), v1("v1");
analyzer->Bind(v0, Range(0, target_vectorized_size)); analyzer->Bind(v0, Range(0, target_vectorized_size));
analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv( analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv(
...@@ -241,7 +262,8 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, ...@@ -241,7 +262,8 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
Substitute(expr, {{var, v0 + v1 * target_vectorized_size}})); Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size)); Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed); PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
// This simplify is necessary for thread region specifiled
// This simplify is necessary for thread region specified
// optimizations. // optimizations.
expr_vectorized = analyzer->Simplify(expr_vectorized); expr_vectorized = analyzer->Simplify(expr_vectorized);
auto ramp_node = expr_vectorized.as<RampNode>(); auto ramp_node = expr_vectorized.as<RampNode>();
......
...@@ -28,8 +28,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16") ...@@ -28,8 +28,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16")
out_idx=[1], out_idx=[1],
target="cuda", target="cuda",
pass_configs={ pass_configs={
"tl.disable_warp_specialized": True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
"tl.disable_tma_lower": True tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
}) })
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a) b = kernel(a)
...@@ -42,5 +42,49 @@ def test_tilelang_copy(): ...@@ -42,5 +42,49 @@ def test_tilelang_copy():
run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float") run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float")
def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
    """Build a tile-level copy kernel whose input is a strided (non-contiguous) tensor.

    ``A`` is declared as an (M, N) view over storage whose row stride is ``NN``
    (so ``NN >= N``; the trailing stride is 1), while ``B`` is a contiguous
    (M, N) output. The kernel copies ``A`` into ``B`` tile by tile.

    Args:
        M, N: logical (rows, cols) of the copied region.
        NN: physical row stride of A's underlying storage; an int or a
            symbolic variable.
        block_M, block_N: tile sizes along rows / cols.
        dtype: element dtype name (e.g. "float16").

    Returns:
        The ``T.prim_func`` implementing the copy.
    """

    @T.prim_func
    def main(
            A: T.StridedTensor((M, N), (NN, 1), dtype),
            B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context: 2D launch grid over column/row tiles.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j]

    return main
def run_tilelang_copy_with_stride(M=1024,
                                  N=1024,
                                  NN=2048,
                                  block_M=128,
                                  block_N=128,
                                  dtype="float16"):
    """Compile and validate the strided-copy kernel against a torch view.

    Builds a kernel whose input is declared with row stride ``NN``, then feeds
    it the non-contiguous view ``a[:, :N]`` of an (M, NN) CUDA tensor and
    checks the contiguous output matches the view.

    Args:
        M, N: logical (rows, cols) copied by the kernel.
        NN: physical row stride; either an int strictly greater than N, or a
            symbolic ``T.Var`` (a concrete value of ``2 * N`` is used for the
            host-side test tensor).
        block_M, block_N: tile sizes.
        dtype: torch dtype name for the tensors.

    Raises:
        ValueError: if a concrete ``NN`` is not greater than ``N``.
    """
    if isinstance(NN, int) and NN <= N:
        # Raise instead of `assert` so the check survives `python -O`.
        raise ValueError(f"NN ({NN}) must be greater than N ({N})")
    program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        })
    if isinstance(NN, T.Var):
        # A symbolic stride is bound at call time; pick a concrete value for
        # the host-side tensor that exercises a genuinely strided view.
        NN = N * 2
    a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype))
    b = kernel(a[:, :N])
    torch.testing.assert_close(b, a[:, :N], rtol=1e-2, atol=1e-2)
def test_tilelang_copy_with_stride():
    """Exercise the strided copy with a static and a symbolic row stride."""
    for row_stride in (2048, T.symbolic("NN")):
        run_tilelang_copy_with_stride(M=1024, N=1024, NN=row_stride, block_M=128, block_N=128)
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
...@@ -9,6 +9,7 @@ def _check(original, transformed): ...@@ -9,6 +9,7 @@ def _check(original, transformed):
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.InjectSoftwarePipeline()(mod) mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tl.transform.Simplify()(mod) mod = tl.transform.Simplify()(mod)
mod = tl.transform.LowerOpaqueBlock()(mod)
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True) True)
...@@ -39,32 +40,16 @@ def test_trival_pipeline(): ...@@ -39,32 +40,16 @@ def test_trival_pipeline():
C[tx, i] = B[tx, 0] + T.float32(1) C[tx, i] = B[tx, 0] + T.float32(1)
@T.prim_func @T.prim_func
def expected(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32")): def expected(A_handle: T.handle, C_handle: T.handle):
for tx in T.thread_binding(16, thread="threadIdx.x"): A = T.match_buffer(A_handle, (16, 1), strides=(1, 1))
with T.block(): C = T.match_buffer(C_handle, (16, 1), strides=(1, 1))
T.reads(A[tx, 0]) tx = T.launch_thread("threadIdx.x", 16)
T.writes(C[tx, 0]) B = T.decl_buffer((2, 16, 1), scope="shared")
B = T.alloc_buffer((2, 16, 1), scope="shared") B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
with T.block(): for i in range(0):
T.reads(A[tx, 0]) B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0)
T.writes(B[0, tx, 0]) C[tx, i] = B[i, tx, 0] + T.float32(1.0)
B[0, tx, 0] = A[tx, 0] * T.float32(2.0) C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
with T.block():
T.reads(A[tx, 1:1], B[0:2, tx, 0])
T.writes(B[1:1, tx, 0], C[tx, 0:0])
for i in range(0):
with T.block():
T.reads(A[tx, i + 1])
T.writes(B[i + 1, tx, 0])
B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0)
with T.block():
T.reads(B[i, tx, 0])
T.writes(C[tx, i])
C[tx, i] = B[i, tx, 0] + T.float32(1.0)
with T.block():
T.reads(B[0, tx, 0])
T.writes(C[tx, 0])
C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
_check(before, expected) _check(before, expected)
......
...@@ -124,8 +124,10 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -124,8 +124,10 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.InjectFenceProxy()(mod) mod = tilelang.transform.InjectFenceProxy()(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod) mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.NarrowDataType(32)(mod) mod = tir.transform.NarrowDataType(32)(mod)
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
mod = tilelang.transform.FlattenBuffer()(mod) mod = tilelang.transform.FlattenBuffer()(mod)
# ConfigIndexBitwidth must be applied after FlattenBuffer
# as it will flatten index computing
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
mod = tir.transform.Simplify()(mod) mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod) mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
......
...@@ -155,21 +155,31 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -155,21 +155,31 @@ class CtypesKernelAdapter(BaseKernelAdapter):
adapter._post_init() adapter._post_init()
return adapter return adapter
def _process_dynamic_symbolic(self): def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
"""Extract information about dynamic shapes from the TIR function. """Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension) Maps symbolic variables to their corresponding (id, buffer_index, dimension)
for runtime shape resolution. for runtime shape resolution.
id represents shape or stride, 0 represents shape, 1 represents stride
""" """
func = self.prim_func func = self.prim_func
params = func.params params = func.params
buffer_map = func.buffer_map buffer_map = func.buffer_map
dynamic_symbolic_map = {} dynamic_symbolic_map = {}
for i, param in enumerate(params): for i, param in enumerate(params):
buffer = buffer_map[param] if param in buffer_map:
for j, shape in enumerate(buffer.shape): buffer = buffer_map[param]
if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map): for j, shape in enumerate(buffer.shape):
dynamic_symbolic_map[shape] = (i, j) if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
(shape not in params)):
dynamic_symbolic_map[shape] = (0, i, j)
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
for j, stride in enumerate(buffer.strides):
if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
(stride not in params)):
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map return dynamic_symbolic_map
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None): def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
...@@ -228,8 +238,11 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -228,8 +238,11 @@ class CtypesKernelAdapter(BaseKernelAdapter):
args.append(tensor) args.append(tensor)
# dynamic symbolics # dynamic symbolics
for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items(): for _, (ref_id, buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
args.append(ins[buffer_idx].shape[shape_idx]) if ref_id == 0:
args.append(ins[buffer_idx].shape[shape_idx])
else:
args.append(ins[buffer_idx].stride(shape_idx))
# if stream is not None, we need to pass the stream to the library # if stream is not None, we need to pass the stream to the library
if stream is None: if stream is None:
......
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from ..base import BaseKernelAdapter
import ctypes import ctypes
import fcntl
import hashlib
import logging
import site
import sys
import sysconfig
import torch
import os
from pathlib import Path
from typing import List, Optional, Union, Callable, Dict, Tuple, Any from typing import List, Optional, Union, Callable, Dict, Tuple, Any
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.target import Target from tvm.target import Target
from tilelang.engine.param import KernelParam from tilelang.engine.param import KernelParam
from tvm import tir from tvm import tir
from tvm.relax import TensorType from tvm.relax import TensorType
from tilelang.jit.adapter.base import BaseKernelAdapter
from tilelang.jit.adapter.wrapper import TLWrapper from tilelang.jit.adapter.wrapper import TLWrapper
from tilelang.jit.adapter.libgen import LibraryGenerator from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target
...@@ -15,15 +26,6 @@ from tilelang.utils.target import determine_target ...@@ -15,15 +26,6 @@ from tilelang.utils.target import determine_target
from tilelang.utils.language import retrieve_func_from_module from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.tensor import map_torch_type from tilelang.utils.tensor import map_torch_type
from tilelang.contrib.cc import get_cplus_compiler from tilelang.contrib.cc import get_cplus_compiler
import torch
import sys
import sysconfig
import hashlib
import os
import fcntl
from pathlib import Path
import logging
import site
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -116,15 +118,15 @@ with open(cython_wrapper_path, "r") as f: ...@@ -116,15 +118,15 @@ with open(cython_wrapper_path, "r") as f:
with open(md5_path, "r") as f: with open(md5_path, "r") as f:
cached_hash = f.read().strip() cached_hash = f.read().strip()
if cached_hash == code_hash: if cached_hash == code_hash:
logger.debug("Cython jit adapter is up to date, no need to compile...") logger.debug("Cython JIT adapter is up to date, no need to compile...")
need_compile = False need_compile = False
else: else:
logger.info("Cython jit adapter is out of date, need to recompile...") logger.info("Cython JIT adapter is out of date, need to recompile...")
else: else:
logger.info("No cached version found for cython jit adapter, need to compile...") logger.info("No cached version found for Cython JIT adapter, need to compile...")
if need_compile: if need_compile:
logger.info("Waiting for lock to compile cython jit adapter...") logger.info("Waiting for lock to compile Cython JIT adapter...")
with open(lock_file, 'w') as lock: with open(lock_file, 'w') as lock:
fcntl.flock(lock.fileno(), fcntl.LOCK_EX) fcntl.flock(lock.fileno(), fcntl.LOCK_EX)
try: try:
...@@ -138,7 +140,7 @@ with open(cython_wrapper_path, "r") as f: ...@@ -138,7 +140,7 @@ with open(cython_wrapper_path, "r") as f:
need_compile = False need_compile = False
if need_compile: if need_compile:
logger.info("Compiling cython jit adapter...") logger.info("Compiling Cython JIT adapter...")
temp_path = cache_dir / f"temp_{code_hash}.so" temp_path = cache_dir / f"temp_{code_hash}.so"
with open(md5_path, "w") as f: with open(md5_path, "w") as f:
...@@ -159,7 +161,7 @@ with open(cython_wrapper_path, "r") as f: ...@@ -159,7 +161,7 @@ with open(cython_wrapper_path, "r") as f:
except Exception as e: except Exception as e:
if 'temp_path' in locals() and temp_path.exists(): if 'temp_path' in locals() and temp_path.exists():
temp_path.unlink() temp_path.unlink()
raise Exception(f"Failed to compile cython jit adapter: {e}") from e raise Exception(f"Failed to compile Cython JIT adapter: {e}") from e
finally: finally:
if lock_file.exists(): if lock_file.exists():
lock_file.unlink() lock_file.unlink()
...@@ -195,11 +197,14 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -195,11 +197,14 @@ class CythonKernelAdapter(BaseKernelAdapter):
ptr_map: Optional[Dict[int, str]] = None ptr_map: Optional[Dict[int, str]] = None
# Maps buffer variables to their corresponding dtypes # Maps buffer variables to their corresponding dtypes
buffer_dtype_map: Optional[Dict[tir.Var, Tuple[int, torch.dtype]]] = None buffer_dtype_map: Optional[Dict[tir.Var, Tuple[int, torch.dtype]]] = None
# Maps buffer variables to their corresponding static shapes # Maps buffer variables to their corresponding static shapes and strides,
# { # e.g., {
# "A": [(0, 16), (1, 16)] -> represents A.shape = (16, 16) # "A": [(0, 16), (1, 16)] -> represents A.shape/strides = (16, 16)
# } # }
static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
static_strides_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
# Contains contiguous buffers
static_contiguous_list: Optional[List[tir.Var]] = None
# Maps buffer variables to their corresponding devices # Maps buffer variables to their corresponding devices
buffer_device_map: Optional[Dict[tir.Var, Tuple[int, torch.device]]] = None buffer_device_map: Optional[Dict[tir.Var, Tuple[int, torch.device]]] = None
# Pass configs for the compiler # Pass configs for the compiler
...@@ -239,9 +244,13 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -239,9 +244,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.dynamic_symbolic_map = self._process_dynamic_symbolic() self.dynamic_symbolic_map = self._process_dynamic_symbolic()
self.buffer_dtype_map = self._process_buffer_dtype() self.buffer_dtype_map = self._process_buffer_dtype()
self.ptr_map = self._process_ptr_map() self.ptr_map = self._process_ptr_map()
self.static_shape_map = self._process_static_shape()
self.buffer_device_map = self._process_buffer_device() self.buffer_device_map = self._process_buffer_device()
static_buffer_infos = self._process_static_buffer_infos()
self.static_shape_map = static_buffer_infos[0]
self.static_strides_map = static_buffer_infos[1]
self.static_contiguous_list = static_buffer_infos[2]
self.verbose = verbose self.verbose = verbose
self.wrapper = TLWrapper(self.target) self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target, verbose=verbose) self.lib_generator = LibraryGenerator(self.target, verbose=verbose)
...@@ -269,6 +278,8 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -269,6 +278,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.cython_wrapper.set_dynamic_symbolic_map(self.dynamic_symbolic_map) self.cython_wrapper.set_dynamic_symbolic_map(self.dynamic_symbolic_map)
self.cython_wrapper.set_buffer_dtype_map(self.buffer_dtype_map) self.cython_wrapper.set_buffer_dtype_map(self.buffer_dtype_map)
self.cython_wrapper.set_static_shape_map(self.static_shape_map) self.cython_wrapper.set_static_shape_map(self.static_shape_map)
self.cython_wrapper.set_static_strides_map(self.static_strides_map)
self.cython_wrapper.set_static_contiguous_list(self.static_contiguous_list)
self.cython_wrapper.set_buffer_device_map(self.buffer_device_map) self.cython_wrapper.set_buffer_device_map(self.buffer_device_map)
self.cython_wrapper.set_ptr_map(self.ptr_map) self.cython_wrapper.set_ptr_map(self.ptr_map)
self._post_init() self._post_init()
...@@ -301,10 +312,14 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -301,10 +312,14 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter.dynamic_symbolic_map = adapter._process_dynamic_symbolic() adapter.dynamic_symbolic_map = adapter._process_dynamic_symbolic()
adapter.buffer_dtype_map = adapter._process_buffer_dtype() adapter.buffer_dtype_map = adapter._process_buffer_dtype()
adapter.static_shape_map = adapter._process_static_shape()
adapter.ptr_map = adapter._process_ptr_map() adapter.ptr_map = adapter._process_ptr_map()
adapter.buffer_device_map = adapter._process_buffer_device() adapter.buffer_device_map = adapter._process_buffer_device()
static_buffer_infos = adapter._process_static_buffer_infos()
adapter.static_shape_map = static_buffer_infos[0]
adapter.static_strides_map = static_buffer_infos[1]
adapter.static_contiguous_list = static_buffer_infos[2]
adapter.verbose = verbose adapter.verbose = verbose
adapter.lib_generator = LibraryGenerator(adapter.target, verbose=verbose) adapter.lib_generator = LibraryGenerator(adapter.target, verbose=verbose)
adapter.lib_generator.assign_pass_configs(pass_configs) adapter.lib_generator.assign_pass_configs(pass_configs)
...@@ -322,17 +337,20 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -322,17 +337,20 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map) adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map)
adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map) adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map)
adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map) adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map)
adapter.cython_wrapper.set_static_strides_map(adapter.static_strides_map)
adapter.cython_wrapper.set_static_contiguous_list(adapter.static_contiguous_list)
adapter.cython_wrapper.set_buffer_device_map(adapter.buffer_device_map) adapter.cython_wrapper.set_buffer_device_map(adapter.buffer_device_map)
adapter.cython_wrapper.set_ptr_map(adapter.ptr_map) adapter.cython_wrapper.set_ptr_map(adapter.ptr_map)
adapter._post_init() adapter._post_init()
return adapter return adapter
def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]: def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
"""Extract information about dynamic shapes from the TIR function. """Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension) Maps symbolic variables to their corresponding (id, buffer_index, dimension)
for runtime shape resolution. for runtime shape resolution.
id represents shape or stride, 0 represents shape, 1 represents stride
""" """
func = self.prim_func func = self.prim_func
params = func.params params = func.params
...@@ -344,7 +362,14 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -344,7 +362,14 @@ class CythonKernelAdapter(BaseKernelAdapter):
for j, shape in enumerate(buffer.shape): for j, shape in enumerate(buffer.shape):
if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
(shape not in params)): (shape not in params)):
dynamic_symbolic_map[shape] = (i, j) dynamic_symbolic_map[shape] = (0, i, j)
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
for j, stride in enumerate(buffer.strides):
if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
(stride not in params)):
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map return dynamic_symbolic_map
def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]: def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]:
...@@ -377,7 +402,10 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -377,7 +402,10 @@ class CythonKernelAdapter(BaseKernelAdapter):
ptr_map[i] = param.name ptr_map[i] = param.name
return ptr_map return ptr_map
def _process_static_shape(self) -> Dict[tir.Var, List[Tuple[int, int]]]: def _process_static_buffer_infos(self) -> \
Tuple[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]],
Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]],
List[Tuple[tir.Var]]]:
"""Extract information about static shapes from the TIR function. """Extract information about static shapes from the TIR function.
Maps buffer variables to their corresponding static shapes. Maps buffer variables to their corresponding static shapes.
...@@ -386,17 +414,27 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -386,17 +414,27 @@ class CythonKernelAdapter(BaseKernelAdapter):
params = func.params params = func.params
buffer_map = func.buffer_map buffer_map = func.buffer_map
static_shape_map = {} static_shape_map = {}
static_strides_map = {}
static_contiguous_list = list()
for i, param in enumerate(params): for i, param in enumerate(params):
if param in buffer_map: if param in buffer_map:
buffer = buffer_map[param] buffer = buffer_map[param]
name = buffer.name static_shape, static_strides = [], []
shape = buffer.shape for j, s in enumerate(buffer.shape):
static_shape = []
for j, s in enumerate(shape):
if isinstance(s, tir.IntImm): if isinstance(s, tir.IntImm):
static_shape.append((j, s.value)) static_shape.append((j, s.value))
static_shape_map[name] = (i, static_shape) for j, s in enumerate(buffer.strides):
return static_shape_map if isinstance(s, tir.IntImm):
static_strides.append((j, s.value))
is_contiguous, prod = True, 1
for dim, stride in reversed(list(zip(buffer.shape, buffer.strides))):
is_contiguous &= bool(stride == prod)
prod *= dim
static_shape_map[buffer.name] = (i, static_shape)
static_strides_map[buffer.name] = (i, static_strides)
if is_contiguous:
static_contiguous_list.append((i, buffer.name))
return static_shape_map, static_strides_map, static_contiguous_list
def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]: def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]:
"""Extract information about buffer devices from the TIR function. """Extract information about buffer devices from the TIR function.
...@@ -473,7 +511,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -473,7 +511,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
@property @property
def is_dynamic(self): def is_dynamic(self):
"""Indicates whether the kernel handles dynamic shapes.""" """Indicates whether the kernel handles dynamic shapes."""
return (self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0) return self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0
def get_kernel_source(self, kernel_only: bool = False): def get_kernel_source(self, kernel_only: bool = False):
"""Returns the source code of the compiled kernel.""" """Returns the source code of the compiled kernel."""
......
...@@ -11,17 +11,19 @@ from tilelang.utils.tensor import map_torch_type ...@@ -11,17 +11,19 @@ from tilelang.utils.tensor import map_torch_type
cdef class CythonKernelWrapper: cdef class CythonKernelWrapper:
# Class attributes to store kernel configuration and library reference # Class attributes to store kernel configuration and library reference
cdef: cdef:
object dynamic_symbolic_map # Maps dynamic dimensions to their corresponding tensor indices object dynamic_symbolic_map # Maps dynamic dimensions to their corresponding tensor indices
object buffer_device_map # Maps buffer variables to their corresponding devices object buffer_device_map # Maps buffer variables to their corresponding devices
object buffer_dtype_map # Maps buffer variables to their corresponding dtypes object buffer_dtype_map # Maps buffer variables to their corresponding dtypes
object static_shape_map # Maps buffer variables to their corresponding static shapes object static_shape_map # Maps buffer variables to their corresponding static shapes
object ptr_map # Maps pointer arguments to their corresponding buffer indices object static_strides_map # Maps buffer variables to their corresponding static strides
list result_idx # Indices of output tensors in the params list object static_contiguous_list # A list contains contiguous buffers
list params # List of parameter specifications (includes both inputs and outputs) object ptr_map # Maps pointer arguments to their corresponding buffer indices
object lib # Reference to the compiled library containing the kernel list result_idx # Indices of output tensors in the params list
list params # List of parameter specifications (includes both inputs and outputs)
object lib # Reference to the compiled library containing the kernel
# Add new cache attributes # Add new cache attributes
list param_dtypes # Cache for parameter dtypes list param_dtypes # Cache for parameter dtypes
list param_shapes # Cache for parameter shapes as native Python lists list param_shapes # Cache for parameter shapes as native Python lists
object get_current_device object get_current_device
def __cinit__(self, result_idx, params, lib): def __cinit__(self, result_idx, params, lib):
...@@ -57,6 +59,14 @@ cdef class CythonKernelWrapper: ...@@ -57,6 +59,14 @@ cdef class CythonKernelWrapper:
self.static_shape_map = static_shape_map self.static_shape_map = static_shape_map
return self return self
def set_static_strides_map(self, static_strides_map):
self.static_strides_map = static_strides_map
return self
def set_static_contiguous_list(self, static_contiguous_list):
self.static_contiguous_list = static_contiguous_list
return self
def set_ptr_map(self, ptr_map): def set_ptr_map(self, ptr_map):
self.ptr_map = ptr_map self.ptr_map = ptr_map
return self return self
...@@ -94,15 +104,41 @@ cdef class CythonKernelWrapper: ...@@ -94,15 +104,41 @@ cdef class CythonKernelWrapper:
cpdef void _check_static_shape(self, list tensor_list): cpdef void _check_static_shape(self, list tensor_list):
for param, (buffer_idx, shape_list) in self.static_shape_map.items(): for param, (buffer_idx, shape_list) in self.static_shape_map.items():
tensor = tensor_list[buffer_idx] tensor = tensor_list[buffer_idx]
if isinstance(tensor, torch.Tensor): if not isinstance(tensor, torch.Tensor):
for shape_idx, expected_shape in shape_list: # otherwise, maybe torch.data_ptr() for T.ptr inputs
actual_shape = tensor.shape[shape_idx] continue
if actual_shape != expected_shape: for shape_idx, expected_shape in shape_list:
raise ValueError( actual_shape = tensor.shape[shape_idx]
f"Static shape mismatch for parameter {param}: " if actual_shape != expected_shape:
f"expected {expected_shape} at index {shape_idx}, " raise ValueError(
f"got {actual_shape}" f"Static shape mismatch for parameter {param}: "
) f"expected {expected_shape} at index {shape_idx}, "
f"got {actual_shape}"
)
cpdef void _check_static_strides(self, list tensor_list):
for param, (buffer_idx, strides_list) in self.static_strides_map.items():
tensor = tensor_list[buffer_idx]
if not isinstance(tensor, torch.Tensor):
# otherwise, maybe torch.data_ptr() for T.ptr inputs
continue
for stride_idx, expected_stride in strides_list:
actual_stride = tensor.stride(stride_idx)
if actual_stride != expected_stride:
raise ValueError(
f"Static stride mismatch for parameter {param}: "
f"expected {expected_stride} at index {stride_idx}, "
f"got {actual_stride}"
)
cpdef void _check_static_contiguous(self, list tensor_list):
for buffer_idx, param in self.static_contiguous_list:
tensor = tensor_list[buffer_idx]
if not isinstance(tensor, torch.Tensor):
# otherwise, maybe torch.data_ptr() for T.ptr inputs
continue
if not tensor.is_contiguous():
raise ValueError(f"Expected parameter {param} to be a contiguous tensor")
cpdef forward(self, list inputs, int64_t stream = -1, bint skip_tensor_validation = False): cpdef forward(self, list inputs, int64_t stream = -1, bint skip_tensor_validation = False):
# Validate input dimensions and prepare for kernel execution # Validate input dimensions and prepare for kernel execution
...@@ -140,7 +176,7 @@ cdef class CythonKernelWrapper: ...@@ -140,7 +176,7 @@ cdef class CythonKernelWrapper:
if isinstance(s, tir.Var): if isinstance(s, tir.Var):
for key in self.dynamic_symbolic_map: for key in self.dynamic_symbolic_map:
if(str(s) == str(key)): if(str(s) == str(key)):
ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[key] ref_id, ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[key]
shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx]) shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
else: # Already converted to Python int during initialization else: # Already converted to Python int during initialization
shape.append(s) shape.append(s)
...@@ -155,6 +191,13 @@ cdef class CythonKernelWrapper: ...@@ -155,6 +191,13 @@ cdef class CythonKernelWrapper:
else: else:
tensor = inputs[ins_idx] tensor = inputs[ins_idx]
ins_idx += 1 ins_idx += 1
# TODO(chenggang): remove this check or rewrite by ourselves?
if isinstance(tensor, torch.Tensor) and tensor._base is not None and not tensor.is_contiguous():
base_tensor = tensor._base.as_strided(tensor._base.shape, tensor.stride())
if torch._debug_has_internal_overlap(base_tensor):
raise ValueError(f"Cannot use an overlapping tensor"
f"(shape={tensor.shape}, strides={tensor.stride()}, "
f"overlap={torch._debug_has_internal_overlap(base_tensor)}) as the kernel input")
tensor_list.append(tensor) tensor_list.append(tensor)
# Convert tensor pointers to C void pointers for kernel call # Convert tensor pointers to C void pointers for kernel call
...@@ -172,8 +215,6 @@ cdef class CythonKernelWrapper: ...@@ -172,8 +215,6 @@ cdef class CythonKernelWrapper:
call_args = [] call_args = []
for i, tensor in enumerate(tensor_list): for i, tensor in enumerate(tensor_list):
if isinstance(tensor, torch.Tensor): if isinstance(tensor, torch.Tensor):
if not tensor.is_contiguous():
raise ValueError(f"Input tensor at index {i} must be contiguous")
call_args.append(ctypes.c_void_p(tensor.data_ptr())) call_args.append(ctypes.c_void_p(tensor.data_ptr()))
elif isinstance(tensor, (int, float, bool)): elif isinstance(tensor, (int, float, bool)):
if i in self.ptr_map: if i in self.ptr_map:
...@@ -191,10 +232,15 @@ cdef class CythonKernelWrapper: ...@@ -191,10 +232,15 @@ cdef class CythonKernelWrapper:
self._check_buffer_device(tensor_list) self._check_buffer_device(tensor_list)
self._check_buffer_dtype(tensor_list) self._check_buffer_dtype(tensor_list)
self._check_static_shape(tensor_list) self._check_static_shape(tensor_list)
self._check_static_strides(tensor_list)
self._check_static_contiguous(tensor_list)
# Add dynamic dimension values to kernel arguments # Add dynamic dimension values to kernel arguments
for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items(): for _, (ref_id, buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
call_args.append(tensor_list[buffer_idx].shape[shape_idx]) if ref_id == 0:
call_args.append(tensor_list[buffer_idx].shape[shape_idx])
else:
call_args.append(tensor_list[buffer_idx].stride(shape_idx))
# Add CUDA stream to kernel arguments # Add CUDA stream to kernel arguments
call_args.append(ctypes.c_void_p(stream)) call_args.append(ctypes.c_void_p(stream))
......
...@@ -234,7 +234,10 @@ class TLCUDASourceWrapper(object): ...@@ -234,7 +234,10 @@ class TLCUDASourceWrapper(object):
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
function_args = [] function_args = []
# Collect function arguments based on primary function's parameters and buffer mappings # Collect function arguments based on primary function's parameters and buffer mappings
# QA(@lei): Why not use device_mod.params?
# device func lack buffer map (to convert buffer handle to buffer)
for param in self.prim_func.params: for param in self.prim_func.params:
if param in self.prim_func.buffer_map: if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param] buffer = self.prim_func.buffer_map[param]
...@@ -484,12 +487,26 @@ class TLCUDASourceWrapper(object): ...@@ -484,12 +487,26 @@ class TLCUDASourceWrapper(object):
def get_dynamic_symbolic_set(self, prim_func): def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function # Determine the set of dynamic symbols used in the function
dynamic_symbolic_set: List[str] = [] dynamic_symbolic_set: List[str] = []
def unique_push_back(name: str):
if name not in dynamic_symbolic_set:
dynamic_symbolic_set.append(name)
for param in prim_func.params: for param in prim_func.params:
if param in prim_func.buffer_map: if param in prim_func.buffer_map:
buffer = prim_func.buffer_map[param] buffer = prim_func.buffer_map[param]
for dim in buffer.shape: for dim in buffer.shape:
if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set): if isinstance(dim, tvm.tir.Var):
dynamic_symbolic_set.append(dim.name) unique_push_back(dim.name)
# Note: In buffer definitions, any dynamic symbols appearing in strides are listed after those in the shape.
for param in prim_func.params:
if param in prim_func.buffer_map:
buffer = prim_func.buffer_map[param]
for stride in buffer.strides:
if isinstance(stride, tvm.tir.Var):
unique_push_back(stride.name)
return dynamic_symbolic_set return dynamic_symbolic_set
def get_init_func(self): def get_init_func(self):
...@@ -549,6 +566,19 @@ class TLCUDASourceWrapper(object): ...@@ -549,6 +566,19 @@ class TLCUDASourceWrapper(object):
return function return function
raise ValueError("Cannot find primary function in the module.") raise ValueError("Cannot find primary function in the module.")
@property
def device_func(self):
if len(self.device_mod.get_global_vars()) == 1:
return self.device_mod[self.device_mod.get_global_vars()[0]]
elif "main" in self.device_mod:
return self.device_mod["main"]
else:
for _, function in self.device_mod.functions.items():
attr = function.attrs
if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
return function
raise ValueError("Cannot find primary function in the module.")
class TLNVRTCSourceWrapper(TLCUDASourceWrapper): class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
""" """
......
...@@ -17,6 +17,7 @@ from .proxy import ( ...@@ -17,6 +17,7 @@ from .proxy import (
make_tensor, # noqa: F401 make_tensor, # noqa: F401
Buffer, # noqa: F401 Buffer, # noqa: F401
Tensor, # noqa: F401 Tensor, # noqa: F401
StridedTensor, # noqa: F401
FragmentBuffer, # noqa: F401 FragmentBuffer, # noqa: F401
SharedBuffer, # noqa: F401 SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401 LocalBuffer, # noqa: F401
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING, Tuple, Union
from typing_extensions import Self from typing_extensions import Self
from tvm import tir from tvm import tir
...@@ -53,7 +53,8 @@ class BufferProxy: ...@@ -53,7 +53,8 @@ class BufferProxy:
def from_ptr(self, def from_ptr(self,
pointer_var: Var, pointer_var: Var,
shape: tuple[PrimExpr, ...], shape: tuple[PrimExpr, ...],
dtype: str = "float32") -> Buffer: dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> Buffer:
"""Create a buffer from a pointer, shape, and data type. """Create a buffer from a pointer, shape, and data type.
Args: Args:
...@@ -64,7 +65,7 @@ class BufferProxy: ...@@ -64,7 +65,7 @@ class BufferProxy:
Returns: Returns:
A buffer created from the given parameters A buffer created from the given parameters
""" """
return match_buffer(pointer_var, shape, dtype=dtype) return match_buffer(pointer_var, shape, dtype=dtype, strides=strides)
class BaseTensorProxy: class BaseTensorProxy:
...@@ -110,16 +111,17 @@ class BaseTensorProxy: ...@@ -110,16 +111,17 @@ class BaseTensorProxy:
) )
def __getitem__(self, keys) -> tir.Buffer: def __getitem__(self, keys) -> tir.Buffer:
if not isinstance(keys, tuple): assert isinstance(keys, tuple)
return self(keys) # Single argument (the shape)
if len(keys) >= 2 and not isinstance(keys[1], str): if all([type(s) not in (tuple, str, list) for s in keys]):
return self(keys) keys = (keys,)
return self(*keys) return self(*keys)
def from_ptr(self, def from_ptr(self,
pointer_var: Var, pointer_var: Var,
shape: tuple[PrimExpr, ...], shape: tuple[PrimExpr, ...],
dtype: str = "float32") -> tir.Buffer: dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
"""Create a buffer from a pointer, shape, and data type. """Create a buffer from a pointer, shape, and data type.
Args: Args:
...@@ -130,16 +132,51 @@ class BaseTensorProxy: ...@@ -130,16 +132,51 @@ class BaseTensorProxy:
Returns: Returns:
A buffer created from the given parameters A buffer created from the given parameters
""" """
return match_buffer(pointer_var, shape, dtype=dtype) return match_buffer(pointer_var, shape, dtype=dtype, strides=strides)
class TensorProxy(BaseTensorProxy): class TensorProxy(BaseTensorProxy):
"""Main tensor proxy class for global scope buffers. """Main tensor proxy class for global scope buffers.
This class implements the default tensor proxy with global memory scope, This class implements the default tensor proxy with global memory scope,
inheriting all functionality from BaseTensorProxy without modifications. the tensor should be by default contiguous.
""" """
@staticmethod
def _construct_strides(shape: Tuple[Any]):
s, strides = 1, [1]
for dim in shape[:0:-1]:
s *= dim
strides.append(s)
return tuple(reversed(strides))
def __call__(self,
shape: Union[Tuple[Any], PrimExpr, int],
dtype: str = "float32",
data=None) -> tir.Buffer:
if isinstance(shape, (int, PrimExpr)):
shape = (shape,)
return super().__call__(
shape, dtype=dtype, strides=TensorProxy._construct_strides(shape), data=data)
class StridedTensorProxy(BaseTensorProxy):
"""Main tensor proxy class for global scope buffers, with strides supported.
This class implements the default tensor proxy with global memory scope, with the stride information required.
"""
def __call__(self,
shape: Tuple[Any],
strides: Tuple[Any],
dtype: str = "float32") -> tir.Buffer:
if len(shape) != len(strides):
raise ValueError("Invalid shape/strides' dimensions")
if not bool(strides[-1] == 1):
# TODO(chenggang): shall we support non-contiguous even for the last dimension?
raise ValueError("The stride of the last dimension must be 1 (contiguous)")
return super().__call__(shape, dtype=dtype, strides=strides)
class FragmentBufferProxy(BaseTensorProxy): class FragmentBufferProxy(BaseTensorProxy):
"""Proxy class for fragment memory buffers. """Proxy class for fragment memory buffers.
...@@ -204,12 +241,16 @@ if TYPE_CHECKING: ...@@ -204,12 +241,16 @@ if TYPE_CHECKING:
def from_ptr(cls, def from_ptr(cls,
pointer_var: Var, pointer_var: Var,
shape: Sequence[PrimExpr, ...], shape: Sequence[PrimExpr, ...],
dtype: str = "float32") -> Self: dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> Self:
... ...
class Tensor(BaseTensor): class Tensor(BaseTensor):
... ...
class StridedTensor(BaseTensor):
...
class FragmentBuffer(BaseTensor): class FragmentBuffer(BaseTensor):
... ...
...@@ -220,6 +261,7 @@ if TYPE_CHECKING: ...@@ -220,6 +261,7 @@ if TYPE_CHECKING:
... ...
else: else:
Tensor = TensorProxy() # pylint: disable=invalid-name Tensor = TensorProxy() # pylint: disable=invalid-name
StridedTensor = StridedTensorProxy() # pylint: disable=invalid-name
FragmentBuffer = FragmentBufferProxy() # pylint: disable=invalid-name FragmentBuffer = FragmentBufferProxy() # pylint: disable=invalid-name
SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
...@@ -250,5 +292,8 @@ def ptr(dtype: Optional[str] = None, ...@@ -250,5 +292,8 @@ def ptr(dtype: Optional[str] = None,
return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var) return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var)
def make_tensor(ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> tir.Buffer: def make_tensor(ptr: Var,
return Tensor.from_ptr(ptr, shape, dtype) shape: tuple[PrimExpr, ...],
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
return Tensor.from_ptr(ptr, shape, dtype, strides)
import inspect
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
from tvm.tir.function import PrimFunc
import tvm.script.parser.tir.entry as _tir_entry import tvm.script.parser.tir.entry as _tir_entry
import inspect from tvm.tir.function import PrimFunc
from tvm.script.parser._core import parse, scan_macro, utils from tvm.script.parser._core import parse, scan_macro, utils
def prim_func(func: Optional[Callable] = None, def prim_func(func: Optional[Callable] = None,
private: bool = False, private: bool = False,
check_well_formed=False) -> Union[PrimFunc, Callable]: check_well_formed: bool = False) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator. """The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters Parameters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment