Commit 7b74bb01 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[JIT] Enhance cython/ctypes wrapper for tma descriptor (#126)



* refactor code

* enhance tutorial

* Enhance error handling and code generation in CUDA and TileLang components

This commit introduces several improvements across multiple files:
- Added more informative error messages in GEMM layout checks
- Updated CUDA codegen to support more flexible function signature generation
- Improved TMA descriptor initialization and kernel dispatch logic
- Refined library generation and source code parsing utilities
- Enhanced error handling in various adapter and wrapper classes

* Add thread tag validation for warp specialization

Introduce a ThreadTagChecker to validate that a PrimFunc only uses threadIdx.x before applying warp specialization. This prevents unintended transformations on kernels with complex thread binding and provides a clear warning to users about potential issues with warp specialization.

* Update TileLang Profiling and Compilation in Flash Decoding Examples

Refactor the profiling and compilation workflow in two flash decoding example scripts:
- Replace `tilelang.lower()` and `tilelang.Profiler()` with `tilelang.compile()`
- Simplify profiler initialization using `get_profiler()`
- Update method calls to use the new profiler and compiled kernel objects
- Maintain existing performance benchmarking and validation logic

* Refactor and clean up code formatting in TileLang testing and adapter modules

This commit includes several code style and formatting improvements:
- Adjust whitespace and line breaks in test files
- Improve code formatting in CUDA source wrapper and adapter utilities
- Enhance readability of function calls and argument handling
- Remove unnecessary whitespace and standardize indentation
- Simplify function signatures and argument parsing

* Refactor CUDA codegen and improve code formatting

This commit includes several improvements to CUDA code generation and formatting:
- Enhance function signature generation in CodeGenTileLangCUDA
- Improve code formatting and readability in CUDA-related files
- Simplify parameter handling and type annotations
- Clean up whitespace and line breaks in codegen and layout files

---------
Co-authored-by: default avatarUbuntu <dlisuser@h100testl730RPS.xu5snccwrbtejcqqalluoku5hb.xx.internal.cloudapp.net>
parent ba311311
......@@ -347,14 +347,13 @@ if __name__ == "__main__":
program = flashattn(
batch, heads, groups, kv_seqlen, dim, tune=args.tune)(
block_N=128, block_H=64, num_split=8, num_stages=2, threads=128)
mod, params = tilelang.lower(program)
mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Auto)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01, max_mismatched_ratio=0.01)
kernel = tilelang.compile(program, out_idx=[6])
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
print("All checks pass.")
latency = mod.do_bench(ref_program, warmup=500)
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod.func, warmup=500, profiler="auto")
latency = profiler.do_bench(kernel.rt_module, warmup=500, profiler="auto")
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
......
......@@ -304,14 +304,14 @@ if __name__ == "__main__":
BLOCK_N = 64 # if D_HEAD <= 128 else 32
program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
ref_program = partial(ref_program, causal=causal)
mod, params = tilelang.lower(program)
mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
mod = tilelang.compile(program, out_idx=[5], target="cuda", execution_backend="dlpack")
profiler = mod.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks passed!")
latency = mod.do_bench(ref_program, warmup=500)
latency = profiler.do_bench(ref_program, warmup=500)
print("{:.2f} ms".format(latency))
print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod, n_warmup=10, n_repeat=10, profiler="tvm")
latency = profiler.do_bench(profiler.mod, n_warmup=10, n_repeat=10, profiler="tvm")
print("{:.4f} ms".format(latency))
print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
......@@ -40,11 +40,12 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
# This is a sugar syntax for parallelized copy
# for i, k in T.Parallel(M, block_K):
# A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
T.copy(A[by * block_M, ko * block_K], A_shared)
# Demonstrate parallelized copy from global to shared for B
for k, j in T.Parallel(block_K, block_N):
B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
# Copy tile of B
T.copy(B[ko * block_K, bx * block_N], B_shared)
# Perform a tile-level GEMM on the shared buffers
# Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
......@@ -63,7 +64,8 @@ func = matmul(1024, 1024, 1024, 128, 128, 32)
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", execution_backend="cython")
# jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", execution_backend="dlpack")
# 3. Test the kernel in Python with PyTorch data
import torch
......@@ -75,6 +77,7 @@ b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)
# Reference multiplication using PyTorch
ref_c = a @ b
......@@ -83,11 +86,11 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5.Profile latency with kernel
profiler = jit_kernel.get_profiler()
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
......
......@@ -298,8 +298,9 @@ Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size) {
Var i = InputPlaceholder(0);
Var j = InputPlaceholder(1);
int vector_size = 128 / element_size;
ICHECK(stride % 8 == 0);
ICHECK(continuous % (vector_size * 4) == 0);
ICHECK(stride % 8 == 0) << "stride=" << stride;
ICHECK(continuous % (vector_size * 4) == 0)
<< "continuous=" << continuous << ", vector_size=" << vector_size;
PrimExpr ts = FloorDiv(i, 8);
PrimExpr s = FloorMod(i, 8);
PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 4);
......
......@@ -84,7 +84,7 @@ public:
PrimExpr threadIdx_z_ext = Integer(1);
};
void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc &f) {
LaunchConfigExtractor extractor;
extractor(f->body);
arith::Analyzer analyzer;
......@@ -1633,7 +1633,72 @@ void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i,
return;
}
void CodeGenTileLangCUDA::AddFunction(const PrimFunc &f) {
void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
const PrimFunc &func,
std::ostream &os) {
PrintFuncPrefix(os);
CodeGenC::PrintType(func->ret_type, os);
CodeGenC::PrintExtraAttrs(func, os);
bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
os << " " << function_name << "(";
for (size_t i = 0; i < func->params.size(); ++i) {
tir::Var v = func->params[i];
std::string vid = AllocVarID(v.get());
if (i > 0) {
os << ", ";
}
if (v.dtype().is_handle()) {
// work around for grid constant parameters.
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (ptr->storage_scope == "grid_constant") {
os << "__grid_constant__ const ";
CodeGenC::PrintType(ptr->element_type, os);
os << ' ' << vid;
continue;
}
}
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os);
}
CodeGenC::PrintType(GetType(v), os);
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
}
if (no_alias) {
PrintRestrict(v, os);
}
} else {
CodeGenC::PrintType(GetType(v), os);
}
os << ' ' << vid;
}
os << ")";
// Register handle data type
// TODO(tvm-team): consider simply keep type info in the
// type annotation(via a normalizing rewriting).
for (const auto &param : func->params) {
if (auto *ptr = param->type_annotation.as<PointerTypeNode>()) {
if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(param.get(), prim->dtype);
}
}
}
}
void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
const PrimFunc &f) {
// If the function has already been forward-declared, this is a
// no-op.
CodeGenC::DeclareFunction(gvar, f);
// clear previous generated state.
this->InitFuncState(f);
// reserve keywords
......@@ -1646,7 +1711,8 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc &f) {
this->PrintFuncPrefix(stream);
CodeGenC::PrintType(f->ret_type, stream);
this->PrintExtraAttrs(f, stream);
this->PrintExtraAttrs(f);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
for (size_t i = 0; i < f->params.size(); ++i) {
......
......@@ -26,7 +26,7 @@ public:
std::string Finish();
// override behavior
void PrintFuncPrefix(std::ostream &os) final;
void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
void PrintExtraAttrs(const PrimFunc &f);
void VisitStmt_(const ForNode *op) final;
void PrintStorageSync(const CallNode *op) final;
void PrintStorageScope(const std::string &scope,
......@@ -54,7 +54,9 @@ public:
void VisitStmt_(const AttrStmtNode *op) final;
// Override this as a work around for __grid_constant__ parameter
void AddFunction(const PrimFunc &f);
void AddFunction(const GlobalVar &gvar, const PrimFunc &f);
void PrintFunctionSignature(const String &function_name, const PrimFunc &func,
std::ostream &os);
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
......
......@@ -47,10 +47,11 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangCUDA: Can only take PrimFunc";
auto gvar = Downcast<GlobalVar>(kv.first);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
cg.AddFunction(f);
cg.AddFunction(gvar, f);
}
std::string code = cg.Finish();
......@@ -78,10 +79,11 @@ String BuildTLDebug(IRModule mod, Target target) {
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangCUDA: Can only take PrimFunc";
auto gvar = Downcast<GlobalVar>(kv.first);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
cg.AddFunction(f);
cg.AddFunction(gvar, f);
}
std::string code = cg.Finish();
......
......@@ -229,7 +229,7 @@ TL_DEVICE void fence_proxy_async() {
TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint64_t state;
uint64_t state = 0;
asm volatile("{\n"
".reg .pred P1;\n"
"mbarrier.arrive.shared.b64 %1, [%0];\n"
......
......@@ -229,7 +229,7 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
}
template <int num_mma> TL_DEVICE void wait_wgmma() {
warpgroup_wait<num_mma>();
cute::warpgroup_wait<num_mma>();
}
template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_sync() {
......
......@@ -42,6 +42,7 @@ public:
PrimFuncNode *fptr = f.CopyOnWrite();
LowerHopperIntrin substituter;
fptr->body = substituter.VisitStmt(f->body);
Map<String, Array<PrimExpr>> init_desc_arg_map;
for (auto [call, var] : substituter.desc_map_) {
// Should allocate 128 bytes for TensorMap on stack
Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
......@@ -57,11 +58,14 @@ public:
init_desc_args.push_back(var);
init_desc_args.insert(init_desc_args.end(), call->args.begin(),
call->args.end());
// add to function attribute
Call init_desc =
Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args);
fptr->body =
LetStmt(var, alloc_desc, SeqStmt({Evaluate(init_desc), fptr->body}));
init_desc_arg_map.Set(var->name_hint, init_desc_args);
}
f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map);
return f;
}
......
......@@ -867,9 +867,53 @@ private:
friend class WarpSpecializedRewriter;
};
class ThreadTagChecker : public StmtExprVisitor {
public:
static bool HasOnlyThreadIdxX(const PrimFunc &f) {
ThreadTagChecker checker;
checker(f->body);
return checker.is_valid_;
}
private:
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
auto iter_var = Downcast<IterVar>(op->node);
if (iter_var->thread_tag.length() > 0 &&
iter_var->thread_tag != "threadIdx.x") {
is_valid_ = false;
}
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kThreadBinding) {
ICHECK(op->thread_binding.defined());
String thread_tag = op->thread_binding.value()->thread_tag;
if (thread_tag.length() > 0 && thread_tag != "threadIdx.x") {
is_valid_ = false;
}
}
StmtExprVisitor::VisitStmt_(op);
}
bool is_valid_ = true;
};
class WarpSpecializedRewriter : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc f) {
// Check if function only uses threadIdx.x before proceeding
if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
LOG(WARNING) << "WarpSpecialize will be disabled because the program "
"uses thread tags other than threadIdx.x\n"
<< "If you want to use warp specialization, please refactor "
"your program to use threadIdx.x only";
// Return original function unchanged if other thread tags are found
return f;
}
auto T = WarpSpecializedRewriter();
T.buffer_lca_ = DetectBufferAccessLCA(f);
for (auto [buffer, _] : T.buffer_lca_)
......
......@@ -361,7 +361,9 @@ def assert_tl_matmul_block_all_dynamic_correctness(
num_stages,
num_threads,
)
mod, params = TL.lower(program)
kernel = tilelang.compile(program)
if trans_A:
A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
else:
......@@ -372,8 +374,7 @@ def assert_tl_matmul_block_all_dynamic_correctness(
B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
kernel(A, B, C)
def ref_program(A, B):
import torch
......@@ -414,8 +415,6 @@ def test_assert_tl_matmul_block_all_dynamic():
"float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16",
"float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(36, 115, 103, False, False, "float16", "float16",
"float16", 64, 64, 32)
if __name__ == "__main__":
......
......@@ -176,7 +176,8 @@ def matmul_fp16xfp4(M,
return main
return kernel_func(block_M=64, block_N=64, block_K=64, num_stages=1, threads=128)
return kernel_func(
block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages, threads=threads)
def ref_program(A, qB):
......@@ -640,4 +641,6 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
assert_simple_impl_float16xfp4_gemm(256, 256, 256, "float16", "float16", "float32", 64, 64, 64,
1, 128)
......@@ -167,10 +167,6 @@ def test_gemm_f32f32f32_nn():
)
def test_gemm_i8i8i32_nn():
run_gemm(512, 1024, 768, False, False, "int8", "int8", "int32", 128, 128, 64)
def test_gemm_f16f16f16_tn():
run_gemm(
512,
......
......@@ -10,8 +10,8 @@ from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relay import TensorType
from tvm import tir
from .wrapper import TLWrapper
from .libgen import LibraryGenerator
from tilelang.jit.adapter.wrapper import TLWrapper
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.utils.target import determine_target
from tilelang.utils.language import retrieve_func_from_module
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
from typing import Union, Optional
from tilelang import tvm as tvm
from tvm import IRModule, tir
from tvm.target import Target
from tilelang.engine.lower import (
is_device_call,
determine_target,
canon_target_host,
)
from tilelang.engine.phase import (
LowerAndLegalize,
OptimizeForTarget,
)
def match_global_kernel(source: str) -> int:
pattern = r"__global__\s+void\s+[__launch_bounds__\(\d+\)\s+]\w+"
matched = re.findall(pattern, source)
assert len(matched) >= 1 # may have statement before kernel
return source.index(matched[0])
def is_cuda_target(target: Target) -> bool:
return target.kind.name == "cuda"
def is_hip_target(target: Target) -> bool:
return target.kind.name == "hip"
def get_annotated_device_mod(
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
target: Union[str, Target] = "auto",
target_host: Optional[Union[str, Target]] = None,
) -> "IRModule":
mod = func_or_mod
if isinstance(func_or_mod, tir.PrimFunc):
func = func_or_mod
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
if isinstance(target, str):
target = determine_target(target)
target_host = canon_target_host(target, target_host)
target_host = tvm.target.Target.canon_target(target_host)
target = tvm.target.Target(target, target_host)
mod = LowerAndLegalize(mod, target)
mod = OptimizeForTarget(mod, target)
device_mod = tir.transform.Filter(is_device_call)(mod)
return device_mod
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from abc import ABC, abstractmethod
from tilelang import tvm as tvm
from typing import Optional, List, Dict, Union
from tvm import IRModule
from tvm.target import Target
from .utils import match_global_kernel, is_cuda_target, is_hip_target, get_annotated_device_mod
import re
import logging
PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """
cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {});
"""
PREDEF_INIT_FUNC = """
extern "C" void init() {{
{}
}}
"""
PREDEF_HOST_FUNC = """
extern "C" void call({}) {{
{}
}}
"""
class BaseWrapper(ABC):
@abstractmethod
def wrap(self, *args, **kwargs):
raise NotImplementedError
logger = logging.getLogger(__name__)
class TLCUDASourceWrapper(object):
_TYPE_MAP = {
"float32": "float",
"float16": "half_t",
"bfloat16": "bfloat16_t",
"e4m3_float8": "__nv_fp8_e4m3",
"e5m2_float8": "__nv_fp8_e5m2",
"float64": "double",
"int64": "int64_t",
"int32": "int",
"uint32": "unsigned int",
"bool": "int8_t",
"int8": "int8_t",
"uint8": "uint8_t",
"int16": "int16_t",
"uchar": "uint8_t",
}
backend = "tl"
def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
self.mod = scheduled_ir_module
self.target = target
self.source = source
self.function_name: Optional[str] = None
self.dynamic_smem_buf: Optional[int] = None
self.block_info: Union[List[int], Dict] = [1, 1, 1]
self.grid_info: Union[List[int], Dict] = [1, 1, 1]
self.parse_source_information()
self.srcpath: Optional[str] = None
self.libpath: Optional[str] = None
self.lib_code: Optional[str] = self.update_lib_code(source)
def parse_source_information(self):
device_mod = get_annotated_device_mod(self.mod, self.target)
assert (len(device_mod.functions) == 1
), "Only support one function in the module for static shape kernel."
for g_var, func in device_mod.functions.items():
self.function_name = g_var.name_hint
attrs = func.attrs
if "dyn_shared_memory_buf" in attrs:
self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"])
if "thread_extent" in attrs:
thread_extent = attrs["thread_extent"]
for tag, extent in thread_extent.items():
if "threadIdx" in tag:
self.block_info["xyz".index(tag[-1])] = extent
elif "blockIdx" in tag:
self.grid_info["xyz".index(tag[-1])] = extent
def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function
dynamic_symbolic_set: List[str] = []
for param in prim_func.params:
buffer = prim_func.buffer_map[param]
for dim in buffer.shape:
if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set):
dynamic_symbolic_set.append(dim.name)
return dynamic_symbolic_set
def get_cuda_init_func(self):
# Initialize an empty string for the CUDA function call
call_str = """"""
# If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call
if self.dynamic_smem_buf is not None:
call_str = (
PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
self.dynamic_smem_buf))
# Format the initialization function using the call_str
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs
def update_lib_code(self, code: str):
# Update the library code with the given code string
self.lib_code = code
# Find the index of the global kernel function in the code
index = match_global_kernel(code)
# Extract the declaration of the function starting from the found index
declaration = code[index:].split(";")[0]
function_name = self.function_name
# Get the CUDA initialization function
init_func = self.get_cuda_init_func()
# Locate the opening brace of the function to insert arguments
index = code.index("{", index)
function_args = []
# Populate the function arguments from the primary function's parameters and buffers
for param in self.prim_func.params:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.name,
"type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
})
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
# Add dynamic symbolic parameters as integers to the function arguments
for dyn_sym in dynamic_symbolic_set:
function_args.append({"name": dyn_sym, "type": "int"})
function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
def func_call_args(s, function_args):
# Extract the function call arguments matching the function definition
pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
matches = re.findall(pattern, s)
call_args = []
for match in matches:
for arg in function_args:
if arg["name"] == match:
call_args.append(match)
return call_args
call_args = ", ".join(func_call_args(declaration, function_args))
block_info, grid_info = self.block_info, self.grid_info
def legalize_c(p):
# Convert TIR expressions to legal C expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p).replace("//", "/")
# Prepare the block and grid dimensions for the CUDA kernel launch
block_str = "dim3({}, {}, {})".format(
legalize_c(block_info[0]),
legalize_c(block_info[1]),
legalize_c(block_info[2]),
)
grid_str = "dim3({}, {}, {})".format(
legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
# Determine the shared memory size, defaulting to 0 if not specified
smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf
# Format the CUDA kernel launch string
call_str = ""
if len(dynamic_symbolic_set) != 0:
call_str += "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0])
else:
call_str += ""
call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str,
smem_str, call_args)
# Create the host function wrapper for the CUDA kernel
host_func = PREDEF_HOST_FUNC.format(def_args, call_str)
# Combine the source, initialization function, and host function to form the complete library code
lib_code = self.source + init_func + host_func
return lib_code
@property
def prim_func(self):
if len(self.mod.get_global_vars()) == 1:
return self.mod[self.mod.get_global_vars()[0]]
elif "main" in self.mod:
return self.mod["main"]
else:
for _, function in self.mod.functions_items():
attr = function.attrs
if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
return function
raise ValueError("Cannot find primary function in the module.")
class TLHIPSourceWrapper(TLCUDASourceWrapper):
def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
super().__init__(scheduled_ir_module, source, target)
def get_hip_init_func(self):
# Initialize an empty string for the CUDA function call
call_str = """"""
# If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call
if self.dynamic_smem_buf is not None:
call_str = PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
self.dynamic_smem_buf)
# Format the initialization function using the call_str
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs
def get_stream_type(self, function_args):
function_args.append({"name": "stream=hipStreamDefault", "type": "hipStream_t"},)
class TLWrapper(BaseWrapper):
def __init__(self, target: Target):
super().__init__()
self.scheduled_ir_module = None
self.target = target
self.lib = None
def assign_optimized_module(self, scheduled_ir_module: IRModule):
self.scheduled_ir_module = scheduled_ir_module
# Get Scheduled Rt Module and return source to be compiled
def wrap(self, c_source: str):
assert self.scheduled_ir_module is not None, "Please assign optimized module first."
if is_cuda_target(self.target):
wrapper_class = TLCUDASourceWrapper
elif is_hip_target(self.target):
wrapper_class = TLHIPSourceWrapper
else:
raise ValueError(f"Unsupported platform: {self.arch.platform}")
wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.target)
return wrapper.lib_code
......@@ -9,8 +9,8 @@ from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relay import TensorType
from tvm import tir
from .wrapper import TLWrapper
from .libgen import LibraryGenerator
from tilelang.jit.adapter.wrapper import TLWrapper
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.utils.target import determine_target
from tilelang.utils.language import retrieve_func_from_module
from tilelang.contrib.cc import get_cplus_compiler
......@@ -175,7 +175,12 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.lib_generator.update_lib_code(self.wrapped_source)
self.lib_generator.compile_lib()
self.lib = self.lib_generator.load_lib()
self.lib.init()
try:
self.lib.init()
except Exception as e:
raise Exception(
f"Failed to initialize the compiled library for {self.target}: {e}") from e
self.cython_wrapper = CythonKernelWrapper(self.dynamic_symbolic_map, self.result_idx,
self.params, self.lib)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from .utils import is_cuda_target, is_hip_target
from tilelang import tvm as tvm
from tilelang.contrib.nvcc import get_target_compute_version
from tvm.target import Target
import ctypes
import os
import tempfile
import subprocess
import logging
from tilelang.env import TILELANG_TEMPLATE_PATH, CUTLASS_INCLUDE_DIR
logger = logging.getLogger(__name__)
class LibraryGenerator(object):
srcpath: Optional[str] = None
libpath: Optional[str] = None
lib_code: Optional[str] = None
def __init__(self, target: Target):
self.target = target
def update_lib_code(self, lib_code: str):
self.lib_code = lib_code
# Assume currently we only support CUDA compilation
def load_lib(self):
return ctypes.CDLL(self.libpath)
def compile_lib(self, timeout: float = None, with_tl: bool = True):
target = self.target
if is_cuda_target(target):
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
compute_version = "".join(get_target_compute_version(target).split("."))
libpath = src.name.replace(".cu", ".so")
command = [
"nvcc",
"-std=c++17",
"-Xcudafe",
"--diag_suppress=177",
"--compiler-options",
"'-fPIC'",
"-lineinfo",
"--shared",
src.name,
"-lcuda",
"-gencode",
f"arch=compute_{compute_version},code=sm_{compute_version}",
]
elif is_hip_target(target):
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
libpath = src.name.replace(".cpp", ".so")
command = [
"hipcc",
"-std=c++17",
"-fPIC",
"--shared",
src.name,
]
else:
raise ValueError(f"Unsupported target: {target}")
if with_tl:
command += [
"-I" + TILELANG_TEMPLATE_PATH,
"-I" + CUTLASS_INCLUDE_DIR,
]
command += ["-diag-suppress=20013"]
command += ["-o", libpath]
src.write(self.lib_code)
src.flush()
try:
ret = subprocess.run(command, timeout=timeout)
except subprocess.TimeoutExpired:
logger.warning(f"Compilation Timeout! {command}")
return None
if ret.returncode != 0:
logger.warning(f"Compilation Failed! {command}")
return None
self.srcpath = src.name
self.libpath = libpath
def remove_lib(self):
if self.libpath:
os.remove(self.libpath)
self.libpath = None
def get_source_path(self):
return self.srcpath
def get_lib_path(self):
return self.libpath
def set_lib_path(self, libpath):
self.libpath = libpath
def set_src_path(self, srcpath):
self.srcpath = srcpath
......@@ -35,11 +35,14 @@ class LibraryGenerator(object):
if is_cuda_target(target):
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
compute_version = "".join(get_target_compute_version(target).split("."))
if compute_version == "90":
compute_version = "90a"
libpath = src.name.replace(".cu", ".so")
command = [
"nvcc",
"-std=c++17",
"-std=c++17",
"-w", # Disable all warning messages
"-Xcudafe",
"--diag_suppress=177",
"--compiler-options",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment