"test/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "765206cbd7ec459a2b78bec72bab4b3298c099c9"
Unverified Commit 1b308baf authored by Lei Wang, committed by GitHub

[Language] Introduce `StridedTensor` to support non-contiguous torch inputs (#722)



* Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107

* Support strided tensors

* Refactor target attribute helper functions for improved clarity

* No code changes made in proxy.py and setup.py

* lint fix

* lint fix via gemini

* lint fix

* test fix

* test fix

* lint fix

* Update wrapper.py

* test fix

* Enhance test for InjectSoftwarePipeline by adding LowerOpaqueBlock transformation and updating expected function signature to use match_buffer for better clarity.

* lint fix

---------
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
parent c369d690
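For orientation before the diff: the commit's user-facing change is the `T.StridedTensor` annotation, which lets a kernel accept a non-contiguous torch view directly. A minimal sketch distilled from the test added in this commit (block sizes inlined for brevity; this exact snippet is not part of the diff):

import torch
import tilelang
import tilelang.language as T

M, N, NN = 1024, 1024, 2048  # NN is the row stride of the underlying storage

@T.prim_func
def copy_strided(
        A: T.StridedTensor((M, N), (NN, 1), "float16"),  # logical shape (M, N), row stride NN
        B: T.Tensor((M, N), "float16"),
):
    with T.Kernel(T.ceildiv(N, 128), T.ceildiv(M, 128), threads=128) as (bx, by):
        for i, j in T.Parallel(128, 128):
            B[by * 128 + i, bx * 128 + j] = A[by * 128 + i, bx * 128 + j]

# The commit's test also passes pass_configs disabling TMA lowering and warp specialization.
kernel = tilelang.compile(copy_strided, out_idx=[1], target="cuda")
a = torch.randn(M, NN, device="cuda", dtype=torch.float16)
b = kernel(a[:, :N])  # non-contiguous view with strides (NN, 1); no .contiguous() copy needed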
......@@ -7,8 +7,6 @@ import tilelang.language as T
from tilelang.autotuner import *
from example_fusedmoe_torch import *
# tilelang.disable_cache()
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_shared(d_hidden,
......
......@@ -145,20 +145,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
clear_accum=True,
wg_wait=-1)
T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
T.gemm(
Q_shared_r,
KV_shared_0_r,
acc_s_0,
transpose_B=True,
wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1)
T.barrier_wait(kv_shared_0_pe_is_ready, k % 2)
T.gemm(
Q_pe_local_0,
K_pe_shared_0,
acc_s_0,
transpose_B=True,
wg_wait=-1)
T.gemm(Q_pe_local_0, K_pe_shared_0, acc_s_0, transpose_B=True, wg_wait=-1)
T.wait_wgmma(0)
......@@ -261,20 +251,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
wg_wait=-1)
T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
T.gemm(
Q_shared_r,
KV_shared_1_r,
acc_s_1,
transpose_B=True,
wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1)
T.barrier_wait(kv_shared_1_pe_is_ready, k % 2)
T.gemm(
Q_pe_local_1,
K_pe_shared_1,
acc_s_1,
transpose_B=True,
wg_wait=-1)
T.gemm(Q_pe_local_1, K_pe_shared_1, acc_s_1, transpose_B=True, wg_wait=-1)
T.wait_wgmma(0)
......@@ -308,11 +288,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
# Step 10. compute O1 with KV_shared_1_rd
T.copy(acc_s_1, acc_s_1_cast)
T.gemm(
acc_s_1_cast,
KV_shared_1_r,
acc_o_r,
wg_wait=-1)
T.gemm(acc_s_1_cast, KV_shared_1_r, acc_o_r, wg_wait=-1)
T.copy(acc_s_1_cast, SP1_shared)
T.barrier_arrive(s_shared_ready_barrier)
......
import fcntl
import functools
import hashlib
import io
import subprocess
import shutil
......@@ -12,9 +15,7 @@ from pathlib import Path
import os
import sys
import site
import hashlib
import sysconfig
import functools
import urllib.request
from packaging.version import Version
import platform
......@@ -22,7 +23,6 @@ import multiprocessing
from setuptools.command.build_ext import build_ext
import importlib
import logging
import fcntl
# Configure logging with basic settings
logging.basicConfig(
......@@ -692,15 +692,15 @@ class TilelangExtensionBuild(build_ext):
with open(md5_path, "r") as f:
cached_hash = f.read().strip()
if cached_hash == code_hash:
logger.info("Cython jit adapter is up to date, no need to compile...")
logger.info("Cython JIT adapter is up to date, no need to compile...")
need_compile = False
else:
logger.info("Cython jit adapter is out of date, need to recompile...")
logger.info("Cython JIT adapter is out of date, need to recompile...")
else:
logger.info("No cached version found for cython jit adapter, need to compile...")
logger.info("No cached version found for Cython JIT adapter, need to compile...")
if need_compile:
logger.info("Waiting for lock to compile cython jit adapter...")
logger.info("Waiting for lock to compile Cython JIT adapter...")
with open(lock_file, 'w') as lock:
fcntl.flock(lock.fileno(), fcntl.LOCK_EX)
try:
......@@ -715,7 +715,7 @@ class TilelangExtensionBuild(build_ext):
need_compile = False
if need_compile:
logger.info("Compiling cython jit adapter...")
logger.info("Compiling Cython JIT adapter...")
temp_path = cache_dir / f"temp_{code_hash}.so"
with open(md5_path, "w") as f:
......@@ -736,7 +736,7 @@ class TilelangExtensionBuild(build_ext):
except Exception as e:
if 'temp_path' in locals() and temp_path.exists():
temp_path.unlink()
raise Exception(f"Failed to compile cython jit adapter: {e}") from e
raise Exception(f"Failed to compile Cython JIT adapter: {e}") from e
finally:
if lock_file.exists():
lock_file.unlink()
......
......@@ -1702,6 +1702,76 @@ void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
os << "))";
}
void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
std::ostream &os) { // NOLINT(*)
ICHECK_EQ(op->indices.size(), 1)
<< "Load from non-flat memory not supported.";
ICHECK(!op->predicate.defined())
<< "Predicated buffer load is not supported.";
DataType value_dtype = op->dtype;
PrimExpr index = op->indices[0];
Var buffer_var = op->buffer->data;
DataType element_dtype = op->buffer->dtype;
int lanes = op->dtype.lanes();
// declare type.
if (value_dtype.lanes() == element_dtype.lanes()) {
std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index);
HandleVolatileLoads(ref, op, os);
} else {
bool can_vector_load = false;
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
const RampNode *ramp = index.as<RampNode>();
ICHECK(ramp);
can_vector_load = true;
// arith::ModularSet me = arith::Analyzer().modular_set(ramp->base);
// The condition: {k * coeff + base} divisible by the alignment for any k
// if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes()
// == 0) {
// can_vector_load = true;
// }
}
if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
// A float4_e2m1fn element has 4 bits, which is an incomplete byte.
// So we cannot vector load it.
can_vector_load = false;
}
if (can_vector_load) {
std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
HandleVolatileLoads(ref, op, os);
} else {
std::ostringstream svalue_expr;
std::string sindex = SSAGetID(PrintExpr(index), index.dtype());
std::string vid = GetVarID(buffer_var.get());
DataType elem_type = op->dtype.element_of();
for (int i = 0; i < lanes; ++i) {
std::ostringstream value_temp;
if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
value_temp << "((";
if (buffer_var.get()->dtype.is_handle()) {
auto it = alloc_storage_scope_.find(buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, value_temp);
}
}
PrintType(elem_type, value_temp);
value_temp << "*)" << vid << ')';
} else {
value_temp << vid;
}
value_temp << '[';
PrintVecElemLoad(sindex, index.dtype(), i, value_temp);
value_temp << ']';
PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
}
os << svalue_expr.str();
}
}
}
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
......
......@@ -50,6 +50,7 @@ public:
void VisitStmt_(const EvaluateNode *op) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
void VisitExpr_(const BufferLoadNode *op, std::ostream &os) final;
// Override this as a work around for __grid_constant__ parameter
void AddFunction(const GlobalVar &gvar, const PrimFunc &f);
......
......@@ -22,7 +22,8 @@ struct MinOp {
}
};
template <class Reducer, int threads, int scale, int thread_offset = 0> struct AllReduce {
template <class Reducer, int threads, int scale, int thread_offset = 0>
struct AllReduce {
static_assert(threads == 1024 || threads == 512 || threads == 256 ||
threads == 128 || threads == 64 || threads == 32 ||
threads == 16 || threads == 8 || threads == 4 || threads == 2);
......
......@@ -136,11 +136,23 @@ private:
max_vector_size = gcd_base;
}
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
// Generate strides if the buffer does not define them
auto strides = buffer->strides;
if (buffer->strides.size() == 0) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
}
// Generate and check element offset expression
ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
PrimExpr elem_offset = 0;
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
elem_offset = elem_offset + indices[i] * stride;
stride = stride * buffer->shape[i];
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_,
......@@ -229,10 +241,19 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
ICHECK(target_vectorized_size >= 1);
if (target_vectorized_size == 1)
return true;
// bind thread range
// Extent must be divisible
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
0))
return false;
// The base offset must be divisible
if (!analyzer->CanProveEqual(
FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) {
return false;
}
// Bind thread range
Var v0("v0"), v1("v1");
analyzer->Bind(v0, Range(0, target_vectorized_size));
analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv(
......@@ -241,7 +262,8 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
// This simplify is necessary for thread region specifiled
// This simplify is necessary for thread region specified
// optimizations.
expr_vectorized = analyzer->Simplify(expr_vectorized);
auto ramp_node = expr_vectorized.as<RampNode>();
......
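The vectorizer above now derives the element offset from explicit strides instead of assuming a contiguous layout, and `IndiceCanVectorize` additionally requires the base offset to be divisible by the vector width. A small Python mirror of the offset logic (illustrative helper names, not part of the codebase):

def default_strides(shape):
    # Contiguous layout: the last dimension has stride 1.
    stride, strides = 1, []
    for dim in reversed(shape):
        strides.append(stride)
        stride *= dim
    return list(reversed(strides))

def elem_offset(indices, shape, strides=None):
    # Mirrors the C++ change: use the buffer's strides when present,
    # otherwise fall back to strides derived from the shape.
    strides = strides or default_strides(shape)
    assert len(indices) == len(strides), "Invalid indices and strides"
    return sum(i * s for i, s in zip(indices, strides))

assert elem_offset([2, 3], [16, 16]) == 35           # contiguous: 2*16 + 3
assert elem_offset([2, 3], [16, 16], [32, 1]) == 67  # padded rows: 2*32 + 3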
......@@ -28,8 +28,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16")
out_idx=[1],
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
})
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
......@@ -42,5 +42,49 @@ def test_tilelang_copy():
run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float")
def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.StridedTensor((M, N), (NN, 1), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j]
return main
def run_tilelang_copy_with_stride(M=1024,
N=1024,
NN=2048,
block_M=128,
block_N=128,
dtype="float16"):
if isinstance(NN, int):
assert NN > N, "NN must be greater than N"
program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
})
if isinstance(NN, T.Var):
NN = N * 2
a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a[:, :N])
torch.testing.assert_close(b, a[:, :N], rtol=1e-2, atol=1e-2)
def test_tilelang_copy_with_stride():
run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128)
run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128)
if __name__ == "__main__":
tilelang.testing.main()
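For intuition about what the new test exercises: slicing columns of a wider tensor keeps the parent's row stride, so the view is non-contiguous while each row stays dense. A quick check in plain PyTorch:

import torch

a = torch.randn(1024, 2048)
v = a[:, :1024]           # same storage, rows still 2048 elements apart
print(v.shape)            # torch.Size([1024, 1024])
print(v.stride())         # (2048, 1): parent's row stride is preserved
print(v.is_contiguous())  # False; a dense (1024, 1024) tensor would need strides (1024, 1)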
......@@ -9,6 +9,7 @@ def _check(original, transformed):
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tl.transform.Simplify()(mod)
mod = tl.transform.LowerOpaqueBlock()(mod)
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True)
......@@ -39,32 +40,16 @@ def test_trival_pipeline():
C[tx, i] = B[tx, 0] + T.float32(1)
@T.prim_func
def expected(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32")):
for tx in T.thread_binding(16, thread="threadIdx.x"):
with T.block():
T.reads(A[tx, 0])
T.writes(C[tx, 0])
B = T.alloc_buffer((2, 16, 1), scope="shared")
with T.block():
T.reads(A[tx, 0])
T.writes(B[0, tx, 0])
B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
with T.block():
T.reads(A[tx, 1:1], B[0:2, tx, 0])
T.writes(B[1:1, tx, 0], C[tx, 0:0])
for i in range(0):
with T.block():
T.reads(A[tx, i + 1])
T.writes(B[i + 1, tx, 0])
B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0)
with T.block():
T.reads(B[i, tx, 0])
T.writes(C[tx, i])
C[tx, i] = B[i, tx, 0] + T.float32(1.0)
with T.block():
T.reads(B[0, tx, 0])
T.writes(C[tx, 0])
C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
def expected(A_handle: T.handle, C_handle: T.handle):
A = T.match_buffer(A_handle, (16, 1), strides=(1, 1))
C = T.match_buffer(C_handle, (16, 1), strides=(1, 1))
tx = T.launch_thread("threadIdx.x", 16)
B = T.decl_buffer((2, 16, 1), scope="shared")
B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
for i in range(0):
B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0)
C[tx, i] = B[i, tx, 0] + T.float32(1.0)
C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
_check(before, expected)
......
......@@ -124,8 +124,10 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.InjectFenceProxy()(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
mod = tilelang.transform.FlattenBuffer()(mod)
# ConfigIndexBitwidth must be applied after FlattenBuffer
# as FlattenBuffer flattens the index computation
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
......
......@@ -155,21 +155,31 @@ class CtypesKernelAdapter(BaseKernelAdapter):
adapter._post_init()
return adapter
def _process_dynamic_symbolic(self):
def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
Maps symbolic variables to their corresponding (id, buffer_index, dimension)
for runtime shape resolution.
The id field selects the source: 0 means a shape dimension, 1 means a stride.
"""
func = self.prim_func
params = func.params
buffer_map = func.buffer_map
dynamic_symbolic_map = {}
for i, param in enumerate(params):
buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape):
if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map):
dynamic_symbolic_map[shape] = (i, j)
if param in buffer_map:
buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape):
if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
(shape not in params)):
dynamic_symbolic_map[shape] = (0, i, j)
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
for j, stride in enumerate(buffer.strides):
if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
(stride not in params)):
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
......@@ -228,8 +238,11 @@ class CtypesKernelAdapter(BaseKernelAdapter):
args.append(tensor)
# dynamic symbolics
for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
args.append(ins[buffer_idx].shape[shape_idx])
for _, (ref_id, buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
if ref_id == 0:
args.append(ins[buffer_idx].shape[shape_idx])
else:
args.append(ins[buffer_idx].stride(shape_idx))
# if stream is not None, we need to pass the stream to the library
if stream is None:
......
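At call time the adapter walks this map to append the runtime values in order; a standalone sketch of that resolution step (hypothetical map contents for a single 2-D input):

import torch

# Each entry maps a TIR Var to (ref_id, buffer_index, dim):
# ref_id 0 reads the buffer's shape at dim, ref_id 1 reads its stride.
dynamic_symbolic_map = {"m": (0, 0, 0), "stride_a": (1, 0, 0)}

def resolve_dynamic_args(dynamic_symbolic_map, ins):
    args = []
    for _, (ref_id, buffer_idx, dim) in dynamic_symbolic_map.items():
        args.append(ins[buffer_idx].shape[dim] if ref_id == 0 else ins[buffer_idx].stride(dim))
    return args

a = torch.randn(64, 128)[:, :100]                       # non-contiguous view
print(resolve_dynamic_args(dynamic_symbolic_map, [a]))  # [64, 128]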
"""The profiler and convert to torch utils"""
from ..base import BaseKernelAdapter
import ctypes
import fcntl
import hashlib
import logging
import site
import sys
import sysconfig
import torch
import os
from pathlib import Path
from typing import List, Optional, Union, Callable, Dict, Tuple, Any
from tilelang import tvm as tvm
from tvm.target import Target
from tilelang.engine.param import KernelParam
from tvm import tir
from tvm.relax import TensorType
from tilelang.jit.adapter.base import BaseKernelAdapter
from tilelang.jit.adapter.wrapper import TLWrapper
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target
......@@ -15,15 +26,6 @@ from tilelang.utils.target import determine_target
from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.tensor import map_torch_type
from tilelang.contrib.cc import get_cplus_compiler
import torch
import sys
import sysconfig
import hashlib
import os
import fcntl
from pathlib import Path
import logging
import site
logger = logging.getLogger(__name__)
......@@ -116,15 +118,15 @@ with open(cython_wrapper_path, "r") as f:
with open(md5_path, "r") as f:
cached_hash = f.read().strip()
if cached_hash == code_hash:
logger.debug("Cython jit adapter is up to date, no need to compile...")
logger.debug("Cython JIT adapter is up to date, no need to compile...")
need_compile = False
else:
logger.info("Cython jit adapter is out of date, need to recompile...")
logger.info("Cython JIT adapter is out of date, need to recompile...")
else:
logger.info("No cached version found for cython jit adapter, need to compile...")
logger.info("No cached version found for Cython JIT adapter, need to compile...")
if need_compile:
logger.info("Waiting for lock to compile cython jit adapter...")
logger.info("Waiting for lock to compile Cython JIT adapter...")
with open(lock_file, 'w') as lock:
fcntl.flock(lock.fileno(), fcntl.LOCK_EX)
try:
......@@ -138,7 +140,7 @@ with open(cython_wrapper_path, "r") as f:
need_compile = False
if need_compile:
logger.info("Compiling cython jit adapter...")
logger.info("Compiling Cython JIT adapter...")
temp_path = cache_dir / f"temp_{code_hash}.so"
with open(md5_path, "w") as f:
......@@ -159,7 +161,7 @@ with open(cython_wrapper_path, "r") as f:
except Exception as e:
if 'temp_path' in locals() and temp_path.exists():
temp_path.unlink()
raise Exception(f"Failed to compile cython jit adapter: {e}") from e
raise Exception(f"Failed to compile Cython JIT adapter: {e}") from e
finally:
if lock_file.exists():
lock_file.unlink()
......@@ -195,11 +197,14 @@ class CythonKernelAdapter(BaseKernelAdapter):
ptr_map: Optional[Dict[int, str]] = None
# Maps buffer variables to their corresponding dtypes
buffer_dtype_map: Optional[Dict[tir.Var, Tuple[int, torch.dtype]]] = None
# Maps buffer variables to their corresponding static shapes
# {
# "A": [(0, 16), (1, 16)] -> represents A.shape = (16, 16)
# Maps buffer variables to their corresponding static shapes and strides,
# e.g., {
# "A": [(0, 16), (1, 16)] -> represents A.shape/strides = (16, 16)
# }
static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
static_strides_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
# Buffers that are required to be contiguous
static_contiguous_list: Optional[List[tir.Var]] = None
# Maps buffer variables to their corresponding devices
buffer_device_map: Optional[Dict[tir.Var, Tuple[int, torch.device]]] = None
# Pass configs for the compiler
......@@ -239,9 +244,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.dynamic_symbolic_map = self._process_dynamic_symbolic()
self.buffer_dtype_map = self._process_buffer_dtype()
self.ptr_map = self._process_ptr_map()
self.static_shape_map = self._process_static_shape()
self.buffer_device_map = self._process_buffer_device()
static_buffer_infos = self._process_static_buffer_infos()
self.static_shape_map = static_buffer_infos[0]
self.static_strides_map = static_buffer_infos[1]
self.static_contiguous_list = static_buffer_infos[2]
self.verbose = verbose
self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target, verbose=verbose)
......@@ -269,6 +278,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.cython_wrapper.set_dynamic_symbolic_map(self.dynamic_symbolic_map)
self.cython_wrapper.set_buffer_dtype_map(self.buffer_dtype_map)
self.cython_wrapper.set_static_shape_map(self.static_shape_map)
self.cython_wrapper.set_static_strides_map(self.static_strides_map)
self.cython_wrapper.set_static_contiguous_list(self.static_contiguous_list)
self.cython_wrapper.set_buffer_device_map(self.buffer_device_map)
self.cython_wrapper.set_ptr_map(self.ptr_map)
self._post_init()
......@@ -301,10 +312,14 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter.dynamic_symbolic_map = adapter._process_dynamic_symbolic()
adapter.buffer_dtype_map = adapter._process_buffer_dtype()
adapter.static_shape_map = adapter._process_static_shape()
adapter.ptr_map = adapter._process_ptr_map()
adapter.buffer_device_map = adapter._process_buffer_device()
static_buffer_infos = adapter._process_static_buffer_infos()
adapter.static_shape_map = static_buffer_infos[0]
adapter.static_strides_map = static_buffer_infos[1]
adapter.static_contiguous_list = static_buffer_infos[2]
adapter.verbose = verbose
adapter.lib_generator = LibraryGenerator(adapter.target, verbose=verbose)
adapter.lib_generator.assign_pass_configs(pass_configs)
......@@ -322,17 +337,20 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map)
adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map)
adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map)
adapter.cython_wrapper.set_static_strides_map(adapter.static_strides_map)
adapter.cython_wrapper.set_static_contiguous_list(adapter.static_contiguous_list)
adapter.cython_wrapper.set_buffer_device_map(adapter.buffer_device_map)
adapter.cython_wrapper.set_ptr_map(adapter.ptr_map)
adapter._post_init()
return adapter
def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
Maps symbolic variables to their corresponding (id, buffer_index, dimension)
for runtime shape resolution.
The id field selects the source: 0 means a shape dimension, 1 means a stride.
"""
func = self.prim_func
params = func.params
......@@ -344,7 +362,14 @@ class CythonKernelAdapter(BaseKernelAdapter):
for j, shape in enumerate(buffer.shape):
if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
(shape not in params)):
dynamic_symbolic_map[shape] = (i, j)
dynamic_symbolic_map[shape] = (0, i, j)
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
for j, stride in enumerate(buffer.strides):
if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
(stride not in params)):
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map
def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]:
......@@ -377,7 +402,10 @@ class CythonKernelAdapter(BaseKernelAdapter):
ptr_map[i] = param.name
return ptr_map
def _process_static_shape(self) -> Dict[tir.Var, List[Tuple[int, int]]]:
def _process_static_buffer_infos(self) -> \
Tuple[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]],
Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]],
List[Tuple[tir.Var]]]:
"""Extract information about static shapes from the TIR function.
Maps buffer variables to their static shapes and strides, and records which buffers must be contiguous.
......@@ -386,17 +414,27 @@ class CythonKernelAdapter(BaseKernelAdapter):
params = func.params
buffer_map = func.buffer_map
static_shape_map = {}
static_strides_map = {}
static_contiguous_list = list()
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
name = buffer.name
shape = buffer.shape
static_shape = []
for j, s in enumerate(shape):
static_shape, static_strides = [], []
for j, s in enumerate(buffer.shape):
if isinstance(s, tir.IntImm):
static_shape.append((j, s.value))
static_shape_map[name] = (i, static_shape)
return static_shape_map
for j, s in enumerate(buffer.strides):
if isinstance(s, tir.IntImm):
static_strides.append((j, s.value))
is_contiguous, prod = True, 1
for dim, stride in reversed(list(zip(buffer.shape, buffer.strides))):
is_contiguous &= bool(stride == prod)
prod *= dim
static_shape_map[buffer.name] = (i, static_shape)
static_strides_map[buffer.name] = (i, static_strides)
if is_contiguous:
static_contiguous_list.append((i, buffer.name))
return static_shape_map, static_strides_map, static_contiguous_list
def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]:
"""Extract information about buffer devices from the TIR function.
......@@ -473,7 +511,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
@property
def is_dynamic(self):
"""Indicates whether the kernel handles dynamic shapes."""
return (self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0)
return self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0
def get_kernel_source(self, kernel_only: bool = False):
"""Returns the source code of the compiled kernel."""
......
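The contiguity test in `_process_static_buffer_infos` walks dimensions from innermost to outermost, requiring each stride to equal the running product of the trailing extents. A standalone sketch of the same check:

def is_contiguous(shape, strides):
    # stride[i] must equal the product of all dimensions to its right,
    # so the innermost stride must be 1.
    prod = 1
    for dim, stride in reversed(list(zip(shape, strides))):
        if stride != prod:
            return False
        prod *= dim
    return True

assert is_contiguous((16, 16), (16, 1))      # dense row-major
assert not is_contiguous((16, 16), (32, 1))  # padded rows, e.g. a (16, 32) tensor sliced to (16, 16)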
......@@ -11,17 +11,19 @@ from tilelang.utils.tensor import map_torch_type
cdef class CythonKernelWrapper:
# Class attributes to store kernel configuration and library reference
cdef:
object dynamic_symbolic_map # Maps dynamic dimensions to their corresponding tensor indices
object buffer_device_map # Maps buffer variables to their corresponding devices
object buffer_dtype_map # Maps buffer variables to their corresponding dtypes
object static_shape_map # Maps buffer variables to their corresponding static shapes
object ptr_map # Maps pointer arguments to their corresponding buffer indices
list result_idx # Indices of output tensors in the params list
list params # List of parameter specifications (includes both inputs and outputs)
object lib # Reference to the compiled library containing the kernel
object dynamic_symbolic_map # Maps dynamic dimensions to their corresponding tensor indices
object buffer_device_map # Maps buffer variables to their corresponding devices
object buffer_dtype_map # Maps buffer variables to their corresponding dtypes
object static_shape_map # Maps buffer variables to their corresponding static shapes
object static_strides_map # Maps buffer variables to their corresponding static strides
object static_contiguous_list # List of buffers required to be contiguous
object ptr_map # Maps pointer arguments to their corresponding buffer indices
list result_idx # Indices of output tensors in the params list
list params # List of parameter specifications (includes both inputs and outputs)
object lib # Reference to the compiled library containing the kernel
# Add new cache attributes
list param_dtypes # Cache for parameter dtypes
list param_shapes # Cache for parameter shapes as native Python lists
list param_dtypes # Cache for parameter dtypes
list param_shapes # Cache for parameter shapes as native Python lists
object get_current_device
def __cinit__(self, result_idx, params, lib):
......@@ -57,6 +59,14 @@ cdef class CythonKernelWrapper:
self.static_shape_map = static_shape_map
return self
def set_static_strides_map(self, static_strides_map):
self.static_strides_map = static_strides_map
return self
def set_static_contiguous_list(self, static_contiguous_list):
self.static_contiguous_list = static_contiguous_list
return self
def set_ptr_map(self, ptr_map):
self.ptr_map = ptr_map
return self
......@@ -94,15 +104,41 @@ cdef class CythonKernelWrapper:
cpdef void _check_static_shape(self, list tensor_list):
for param, (buffer_idx, shape_list) in self.static_shape_map.items():
tensor = tensor_list[buffer_idx]
if isinstance(tensor, torch.Tensor):
for shape_idx, expected_shape in shape_list:
actual_shape = tensor.shape[shape_idx]
if actual_shape != expected_shape:
raise ValueError(
f"Static shape mismatch for parameter {param}: "
f"expected {expected_shape} at index {shape_idx}, "
f"got {actual_shape}"
)
if not isinstance(tensor, torch.Tensor):
# non-tensor values may be raw pointers (torch.data_ptr()) for T.ptr inputs
continue
for shape_idx, expected_shape in shape_list:
actual_shape = tensor.shape[shape_idx]
if actual_shape != expected_shape:
raise ValueError(
f"Static shape mismatch for parameter {param}: "
f"expected {expected_shape} at index {shape_idx}, "
f"got {actual_shape}"
)
cpdef void _check_static_strides(self, list tensor_list):
for param, (buffer_idx, strides_list) in self.static_strides_map.items():
tensor = tensor_list[buffer_idx]
if not isinstance(tensor, torch.Tensor):
# non-tensor values may be raw pointers (torch.data_ptr()) for T.ptr inputs
continue
for stride_idx, expected_stride in strides_list:
actual_stride = tensor.stride(stride_idx)
if actual_stride != expected_stride:
raise ValueError(
f"Static stride mismatch for parameter {param}: "
f"expected {expected_stride} at index {stride_idx}, "
f"got {actual_stride}"
)
cpdef void _check_static_contiguous(self, list tensor_list):
for buffer_idx, param in self.static_contiguous_list:
tensor = tensor_list[buffer_idx]
if not isinstance(tensor, torch.Tensor):
# non-tensor values may be raw pointers (torch.data_ptr()) for T.ptr inputs
continue
if not tensor.is_contiguous():
raise ValueError(f"Expected parameter {param} to be a contiguous tensor")
cpdef forward(self, list inputs, int64_t stream = -1, bint skip_tensor_validation = False):
# Validate input dimensions and prepare for kernel execution
......@@ -140,7 +176,7 @@ cdef class CythonKernelWrapper:
if isinstance(s, tir.Var):
for key in self.dynamic_symbolic_map:
if(str(s) == str(key)):
ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[key]
ref_id, ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[key]
shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
else: # Already converted to Python int during initialization
shape.append(s)
......@@ -155,6 +191,13 @@ cdef class CythonKernelWrapper:
else:
tensor = inputs[ins_idx]
ins_idx += 1
# TODO(chenggang): remove this check or rewrite it ourselves?
if isinstance(tensor, torch.Tensor) and tensor._base is not None and not tensor.is_contiguous():
base_tensor = tensor._base.as_strided(tensor._base.shape, tensor.stride())
if torch._debug_has_internal_overlap(base_tensor):
raise ValueError(f"Cannot use an overlapping tensor"
f"(shape={tensor.shape}, strides={tensor.stride()}, "
f"overlap={torch._debug_has_internal_overlap(base_tensor)}) as the kernel input")
tensor_list.append(tensor)
# Convert tensor pointers to C void pointers for kernel call
......@@ -172,8 +215,6 @@ cdef class CythonKernelWrapper:
call_args = []
for i, tensor in enumerate(tensor_list):
if isinstance(tensor, torch.Tensor):
if not tensor.is_contiguous():
raise ValueError(f"Input tensor at index {i} must be contiguous")
call_args.append(ctypes.c_void_p(tensor.data_ptr()))
elif isinstance(tensor, (int, float, bool)):
if i in self.ptr_map:
......@@ -191,10 +232,15 @@ cdef class CythonKernelWrapper:
self._check_buffer_device(tensor_list)
self._check_buffer_dtype(tensor_list)
self._check_static_shape(tensor_list)
self._check_static_strides(tensor_list)
self._check_static_contiguous(tensor_list)
# Add dynamic dimension values to kernel arguments
for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
call_args.append(tensor_list[buffer_idx].shape[shape_idx])
for _, (ref_id, buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
if ref_id == 0:
call_args.append(tensor_list[buffer_idx].shape[shape_idx])
else:
call_args.append(tensor_list[buffer_idx].stride(shape_idx))
# Add CUDA stream to kernel arguments
call_args.append(ctypes.c_void_p(stream))
......
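The `_base` check added to `forward` rejects views whose layout makes distinct indices alias the same memory; `torch._debug_has_internal_overlap` (a private PyTorch helper) returns a nonzero value when such overlap is detected. A quick demonstration:

import torch

base = torch.randn(16)
# A zero stride in dim 0 makes every row alias the same 16 elements.
overlapping = base.as_strided((4, 16), (0, 1))
print(overlapping.is_contiguous())                     # False
print(torch._debug_has_internal_overlap(overlapping))  # nonzero: overlap detected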
......@@ -234,7 +234,10 @@ class TLCUDASourceWrapper(object):
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
function_args = []
# Collect function arguments based on primary function's parameters and buffer mappings
# QA(@lei): Why not use device_mod.params?
# device funcs lack a buffer map (needed to convert buffer handles to buffers)
for param in self.prim_func.params:
if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param]
......@@ -484,12 +487,26 @@ class TLCUDASourceWrapper(object):
def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function
dynamic_symbolic_set: List[str] = []
def unique_push_back(name: str):
if name not in dynamic_symbolic_set:
dynamic_symbolic_set.append(name)
for param in prim_func.params:
if param in prim_func.buffer_map:
buffer = prim_func.buffer_map[param]
for dim in buffer.shape:
if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set):
dynamic_symbolic_set.append(dim.name)
if isinstance(dim, tvm.tir.Var):
unique_push_back(dim.name)
# Note: In buffer definitions, any dynamic symbols appearing in strides are listed after those in the shape.
for param in prim_func.params:
if param in prim_func.buffer_map:
buffer = prim_func.buffer_map[param]
for stride in buffer.strides:
if isinstance(stride, tvm.tir.Var):
unique_push_back(stride.name)
return dynamic_symbolic_set
def get_init_func(self):
......@@ -549,6 +566,19 @@ class TLCUDASourceWrapper(object):
return function
raise ValueError("Cannot find primary function in the module.")
@property
def device_func(self):
if len(self.device_mod.get_global_vars()) == 1:
return self.device_mod[self.device_mod.get_global_vars()[0]]
elif "main" in self.device_mod:
return self.device_mod["main"]
else:
for _, function in self.device_mod.functions.items():
attr = function.attrs
if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
return function
raise ValueError("Cannot find primary function in the module.")
class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"""
......
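The ordering note in `get_dynamic_symbolic_set` matters because the generated host wrapper and the Python adapters must agree on argument order: shape symbols for all buffers come first, then stride symbols, each name emitted once. A toy version of the collection (symbolic dims modeled as strings):

def dynamic_symbolic_set(buffers):
    out = []
    def push(name):
        if isinstance(name, str) and name not in out:
            out.append(name)
    for shape, _ in buffers:    # shape symbols first ...
        for d in shape:
            push(d)
    for _, strides in buffers:  # ... then stride symbols
        for s in strides:
            push(s)
    return out

# A: shape (m, n), strides (s_a, 1); B: shape (m, k), strides (s_b, 1)
print(dynamic_symbolic_set([(("m", "n"), ("s_a", 1)), (("m", "k"), ("s_b", 1))]))
# ['m', 'n', 'k', 's_a', 's_b']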
......@@ -17,6 +17,7 @@ from .proxy import (
make_tensor, # noqa: F401
Buffer, # noqa: F401
Tensor, # noqa: F401
StridedTensor, # noqa: F401
FragmentBuffer, # noqa: F401
SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING
from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING, Tuple, Union
from typing_extensions import Self
from tvm import tir
......@@ -53,7 +53,8 @@ class BufferProxy:
def from_ptr(self,
pointer_var: Var,
shape: tuple[PrimExpr, ...],
dtype: str = "float32") -> Buffer:
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> Buffer:
"""Create a buffer from a pointer, shape, and data type.
Args:
......@@ -64,7 +65,7 @@ class BufferProxy:
Returns:
A buffer created from the given parameters
"""
return match_buffer(pointer_var, shape, dtype=dtype)
return match_buffer(pointer_var, shape, dtype=dtype, strides=strides)
class BaseTensorProxy:
......@@ -110,16 +111,17 @@ class BaseTensorProxy:
)
def __getitem__(self, keys) -> tir.Buffer:
if not isinstance(keys, tuple):
return self(keys)
if len(keys) >= 2 and not isinstance(keys[1], str):
return self(keys)
assert isinstance(keys, tuple)
# Single argument (the shape)
if all([type(s) not in (tuple, str, list) for s in keys]):
keys = (keys,)
return self(*keys)
def from_ptr(self,
pointer_var: Var,
shape: tuple[PrimExpr, ...],
dtype: str = "float32") -> tir.Buffer:
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
"""Create a buffer from a pointer, shape, and data type.
Args:
......@@ -130,16 +132,51 @@ class BaseTensorProxy:
Returns:
A buffer created from the given parameters
"""
return match_buffer(pointer_var, shape, dtype=dtype)
return match_buffer(pointer_var, shape, dtype=dtype, strides=strides)
class TensorProxy(BaseTensorProxy):
"""Main tensor proxy class for global scope buffers.
This class implements the default tensor proxy with global memory scope,
inheriting all functionality from BaseTensorProxy without modifications.
The tensor is contiguous by default.
"""
@staticmethod
def _construct_strides(shape: Tuple[Any]):
s, strides = 1, [1]
for dim in shape[:0:-1]:
s *= dim
strides.append(s)
return tuple(reversed(strides))
def __call__(self,
shape: Union[Tuple[Any], PrimExpr, int],
dtype: str = "float32",
data=None) -> tir.Buffer:
if isinstance(shape, (int, PrimExpr)):
shape = (shape,)
return super().__call__(
shape, dtype=dtype, strides=TensorProxy._construct_strides(shape), data=data)
class StridedTensorProxy(BaseTensorProxy):
"""Main tensor proxy class for global scope buffers, with strides supported.
This class implements the default tensor proxy with global memory scope, with the stride information required.
"""
def __call__(self,
shape: Tuple[Any],
strides: Tuple[Any],
dtype: str = "float32") -> tir.Buffer:
if len(shape) != len(strides):
raise ValueError("Invalid shape/strides' dimensions")
if not bool(strides[-1] == 1):
# TODO(chenggang): shall we support non-contiguous even for the last dimension?
raise ValueError("The stride of the last dimension must be 1 (contiguous)")
return super().__call__(shape, dtype=dtype, strides=strides)
class FragmentBufferProxy(BaseTensorProxy):
"""Proxy class for fragment memory buffers.
......@@ -204,12 +241,16 @@ if TYPE_CHECKING:
def from_ptr(cls,
pointer_var: Var,
shape: Sequence[PrimExpr, ...],
dtype: str = "float32") -> Self:
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> Self:
...
class Tensor(BaseTensor):
...
class StridedTensor(BaseTensor):
...
class FragmentBuffer(BaseTensor):
...
......@@ -220,6 +261,7 @@ if TYPE_CHECKING:
...
else:
Tensor = TensorProxy() # pylint: disable=invalid-name
StridedTensor = StridedTensorProxy() # pylint: disable=invalid-name
FragmentBuffer = FragmentBufferProxy() # pylint: disable=invalid-name
SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
......@@ -250,5 +292,8 @@ def ptr(dtype: Optional[str] = None,
return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var)
def make_tensor(ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> tir.Buffer:
return Tensor.from_ptr(ptr, shape, dtype)
def make_tensor(ptr: Var,
shape: tuple[PrimExpr, ...],
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
return Tensor.from_ptr(ptr, shape, dtype, strides)
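`TensorProxy` now materializes contiguous strides explicitly via `_construct_strides`, while `StridedTensorProxy` takes them from the user and only enforces a unit stride on the last dimension. A quick check of the derivation (standalone copy of the helper):

def construct_strides(shape):
    # Walk the shape from the last dimension to the second,
    # accumulating the running product, exactly as _construct_strides does.
    s, strides = 1, [1]
    for dim in shape[:0:-1]:
        s *= dim
        strides.append(s)
    return tuple(reversed(strides))

assert construct_strides((2, 3, 4)) == (12, 4, 1)  # contiguous row-major
assert construct_strides((16,)) == (1,)            # 1-D: a single unit stride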
import inspect
from typing import Callable, Optional, Union
from tvm.tir.function import PrimFunc
import tvm.script.parser.tir.entry as _tir_entry
import inspect
from tvm.tir.function import PrimFunc
from tvm.script.parser._core import parse, scan_macro, utils
def prim_func(func: Optional[Callable] = None,
private: bool = False,
check_well_formed=False) -> Union[PrimFunc, Callable]:
check_well_formed: bool = False) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
......