Unverified Commit 5c11d245 authored by Lei Wang, committed by GitHub

[Refactor] Merge bulk copy into copy and improve layout inference for bulk copy (#746)

* [Refactor] Merge bulk copy into copy and refactor layout inference for bulk copy

* Deleted the `bulk_copy` operator implementation and its header file, as they are no longer needed.
* Introduced a new function `cuTensorMapType()` to return the data type for CUDA tensor mapping.
* Updated related files to reflect these changes, ensuring that the codebase remains clean and maintainable.

* lint fix

* Fix typos in intrinsic names and remove an unused print statement in block_sparse_attn_tilelang.py; update references from `ptx_ldmatirx` to `ptx_ldmatrix` across multiple files for consistency.

* remove bulk copy

* Refactor copy and atomic add operations to support TMA lower configuration

- Updated `GetCopyInst` to accept a `disable_tma_lower` parameter, allowing conditional use of TMA in bulk load/store operations (a user-facing sketch follows this list).
- Modified `Lower` method in `Copy` to incorporate the new TMA configuration.
- Refactored `AtomicAdd::Lower` to streamline layout inference and vectorization logic.
- Removed unused `disable_tma_lower` field from `LowerArgs` structure for clarity.
- Enhanced atomic add vectorization by replacing the buggy implementation with a more robust loop vectorization approach.
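
A minimal sketch of the user-facing knob behind this change, assuming the `PassConfigKey` usage shown in the reshape tests below; the kernel body is a placeholder staging copy, not code from this commit. Compiling with `TL_DISABLE_TMA_LOWER` steers `GetCopyInst` away from the bulk load/store path.

    import tilelang
    import tilelang.language as T

    @T.prim_func
    def staging_copy(A: T.Tensor((128, 128), "float16"), B: T.Tensor((128, 128), "float16")):
        with T.Kernel(1, threads=128) as (bx):
            A_shared = T.alloc_shared((128, 128), "float16")
            T.copy(A, A_shared)
            T.copy(A_shared, B)

    # With TMA lowering disabled, the copies fall back to the async/SIMT path
    # instead of bulk (TMA) load/store instructions.
    kernel = tilelang.compile(
        staging_copy,
        out_idx=-1,
        pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
    )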

* Enhance TMA bulk copy logic in `LowerBulkCopy` method

- Added a condition to set `desc.swizzle` to `CU_TENSOR_MAP_SWIZZLE_NONE` when `shared_layout` matches `linear_layout`, improving clarity in layout handling.
- Updated warning log to provide more detailed information about fallback scenarios, including source and destination buffer names and shapes, enhancing debugging capabilities.

* lint fix

* Remove fallback logging for non-swizzled global layouts in the `LowerBulkCopy` method to streamline the bulk copy logic; this eliminates unnecessary warning messages about inner box dimensions.

* Enhance reshape kernel compilation in `run_reshape` and `run_reshape_smem_1d_2_2d` functions

- Updated the `tl.compile` method to include `pass_configs` that disable TMA lower and warp specialization, addressing shared memory layout transformation limitations.
- Added TODO comments to indicate the need for further improvements in shared memory handling.

* Update `native_sparse_attention` function to include TMA configuration options

- Added `pass_configs` to the JIT decorator to disable TMA lower and warp specialization, addressing potential issues with shared memory layout transformations (a decorator sketch follows this list).
- Updated comments to clarify modifications in tensor shapes for inference, specifically setting `q` sequence length to 1.
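
A hedged sketch of the decorator form, assuming `tilelang.jit` forwards `pass_configs` the same way `tl.compile` does; the actual `native_sparse_attention` body is not reproduced here, so a trivial factory stands in.

    import tilelang
    import tilelang.language as T

    @tilelang.jit(
        out_idx=[-1],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    def scale_by_two(N=1024, block_N=128, dtype="float16"):

        @T.prim_func
        def main(A: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype)):
            with T.Kernel(T.ceildiv(N, block_N), threads=128) as (bx):
                for i in T.Parallel(block_N):
                    B[bx * block_N + i] = A[bx * block_N + i] * 2

        return main

    # Calling the decorated factory compiles the kernel with the pass configs applied.
    kernel = scale_by_two()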

* Refactor JIT decorator formatting in `native_sparse_attention` function

- Improved readability by reformatting the JIT decorator parameters for `native_sparse_attention`, ensuring consistent style across the codebase.
- No functional changes were made; this update focuses on code clarity and maintainability.

* Enhance thread management and logging in TileLang compilation

- Added a method to check whether printing is enabled during compilation, improving control over logging behavior (a short sketch follows this list).
- Updated the JIT kernel class to utilize the new method for logging compilation status, ensuring consistent and clear output.
- Added comments to clarify the purpose of changes and improve code readability.
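
For reference, a short sketch of how the toggle is driven, assuming the global `Environment` instance shown in the diff lives in `tilelang.env` (the import path is an assumption).

    import os

    # Any of "1", "true", "yes", "on" enables the new start/finish log lines;
    # set it before TileLang compiles a kernel.
    os.environ["TILELANG_PRINT_ON_COMPILATION"] = "1"

    from tilelang.env import env  # assumed module path for the global Environment object

    if env.is_print_on_compilation_enabled():
        print("JITKernel will log when compilation starts and when it completes")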

* Add warp specialization scope and refactor register management in TileLang

- Introduced a new constant `kWarpSpecializationScope` in `builtin.h` for better attribute management.
- Removed the `SetMaxNRegCollector` class and its related logic from `warp_specialized_rewriter.cc`, streamlining the warp specialization process.
- Added functions `annotate_producer_reg_dealloc` and `annotate_consumer_reg_alloc` in `builtin.py` to facilitate register management (a usage sketch follows this list).
- Implemented `AnnotateWarpGroupRegAlloc` in `__init__.py` to inject register allocation calls into warp-specialized functions, enhancing the overall register handling in the compilation process.
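
A condensed sketch of the new flow, distilled from the pass test further down; real warp-specialized kernels also set up mbarriers and TMA, which are omitted here, so treat this as an outline of where the annotations go and how the pass is invoked rather than a complete kernel.

    import tilelang as tl
    import tilelang.language as T
    from tilelang import tvm as tvm

    @T.prim_func
    def ws_stub(A: T.Tensor((128,), "float16")):
        v = T.launch_thread("threadIdx.x", 256)
        with T.block(""):
            T.reads(A[0:128])
            T.writes()
            # Producer warps release registers down to 24, consumer warps grow to 240.
            T.annotate_producer_reg_dealloc(24)
            T.annotate_consumer_reg_alloc(240)
            T.attr([128, 128], "kWarpSpecializationScope", 0)
            if v >= 128:
                T.evaluate(0)  # producer branch (placeholder)
            else:
                T.evaluate(0)  # consumer branch (placeholder)

    mod = tvm.IRModule.from_expr(ws_stub.with_attr("global_symbol", "main"))
    # AnnotateWarpGroupRegAlloc rewrites the two hints into set_max_nreg calls at
    # the top of the producer and consumer branches.
    mod = tl.transform.AnnotateWarpGroupRegAlloc()(mod)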

* Refactor test for InjectSetMaxNReg pass in TileLang

- Improved readability by restructuring conditional checks and assertions in the test cases.
- Enhanced clarity in the collection of `set_max_nreg` calls by simplifying the logic.
- Ensured consistent formatting and spacing throughout the test functions for better maintainability.

* Enhance bulk copy and store checks in `Copy` class

- Updated scope validation for source and destination tensors in the `CheckBulkLoad` and `CheckBulkStore` methods to accept both `shared.dyn` and `shared` (a usage sketch follows this list).
- Modified `CheckLDSMCopy` and `CheckSTSMCopy` methods to accommodate the new scope validation, ensuring compatibility with shared memory configurations.
- Improved logging in `LowerBulkCopy` to provide clearer warnings regarding unsupported swizzle layouts, including source and destination names for better debugging.
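
A small illustration of the relaxed scope check, assuming an SM90-class target where `TargetHasBulkCopy` holds; statically scoped `shared` allocations are now eligible for the bulk-copy path alongside `shared.dyn`.

    import tilelang
    import tilelang.language as T

    @T.prim_func
    def staged_copy(A: T.Tensor((256, 256), "float16"), B: T.Tensor((256, 256), "float16")):
        with T.Kernel(T.ceildiv(256, 128), threads=128) as (bx):
            # `scope="shared"` is now accepted by CheckBulkLoad/CheckBulkStore,
            # in addition to the default `shared.dyn`.
            A_shared = T.alloc_shared((128, 256), "float16", scope="shared")
            T.copy(A[bx * 128, 0], A_shared)
            T.copy(A_shared, B[bx * 128, 0])

    kernel = tilelang.compile(staged_copy, out_idx=-1)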

* lint fix
parent cb37bfef
......@@ -104,6 +104,13 @@ bool TargetHasStmatrix(Target target) {
return arch >= 90;
}
bool TargetHasBulkCopy(Target target) {
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 90;
}
int TargetGetWarpSize(Target target) {
int res = 32;
if (TargetIsCDNA(target))
......
......@@ -25,6 +25,7 @@ bool TargetIsCDNA(Target target);
bool TargetHasAsyncCopy(Target target);
bool TargetHasLdmatrix(Target target);
bool TargetHasStmatrix(Target target);
bool TargetHasBulkCopy(Target target);
int TargetGetWarpSize(Target target);
} // namespace tl
......
/*!
* \file annotate_warp_group_reg_alloc.cc
* \brief Annotate warp group reg alloc for warp specialization
*/
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include <vector>
#include "../op/builtin.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
class SetMaxNRegCollector : public StmtExprVisitor {
public:
static Array<IntImm> Collect(const PrimFunc &f) {
SetMaxNRegCollector collector;
collector(f->body);
return collector.has_no_set_max_nreg_
? Array<IntImm>({IntImm(DataType::Int(32), -1),
IntImm(DataType::Int(32), -1)})
: collector.nreg_;
}
private:
void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(set_max_nreg())) {
int reg_hint = call->args[0].as<IntImmNode>()->value;
int is_inc = call->args[1].as<IntImmNode>()->value;
ICHECK(reg_hint <= 240 && reg_hint >= 24)
<< "Invalid reg hint: " << reg_hint;
ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc;
// producer should decrease register hint while consumer should increase
// register hint
nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
} else if (call->op.same_as(no_set_max_nreg())) {
has_no_set_max_nreg_ = true;
}
}
StmtExprVisitor::VisitStmt_(op);
}
Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), 0)};
bool has_no_set_max_nreg_ = false;
};
class SetMaxNRegInjector : public StmtExprMutator {
public:
static PrimFunc Inject(PrimFunc f) {
auto T = SetMaxNRegInjector();
T.nreg_ = SetMaxNRegCollector::Collect(f);
f.CopyOnWrite()->body = T(f->body);
return f;
}
private:
Stmt VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(set_max_nreg()) ||
call->op.same_as(no_set_max_nreg())) {
// Remove the original set_max_nreg calls as they will be re-inserted
// at appropriate locations
return Evaluate(0);
}
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent &&
Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
thread_iv_ = Downcast<IterVar>(op->node);
need_update_thread_extent_ = false;
AttrStmt attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(op));
if (need_update_thread_extent_) {
thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()};
attr_stmt.CopyOnWrite()->node = thread_iv_;
attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value();
}
thread_iv_ = {};
return attr_stmt;
} else if (op->attr_key == attr::kWarpSpecializationScope) {
auto if_then_else = Downcast<IfThenElse>(op->body);
if (!if_then_else.defined()) {
return StmtExprMutator::VisitStmt_(op);
}
auto producer_body = if_then_else->then_case;
Optional<Stmt> consumer_body = if_then_else->else_case;
ICHECK(consumer_body.defined()) << "Consumer body is undefined";
int dec_reg = nreg_[0].as<IntImmNode>()->value;
int inc_reg = nreg_[1].as<IntImmNode>()->value;
auto inc_reg_stmt = Evaluate(0);
auto dec_reg_stmt = Evaluate(0);
// Only inject if we have valid register hints and no SIMT copy
// For now, we assume no SIMT copy detection is available here
// TODO: Add SIMT copy detection if needed
bool has_simt_copy = false; // Placeholder
if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) {
inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
{inc_reg == 0 ? 240 : inc_reg, 1}));
dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
{dec_reg == 0 ? 24 : dec_reg, 0}));
}
// Inject register setting statements
Array<Stmt> producer_stmts;
producer_stmts.push_back(dec_reg_stmt);
producer_stmts.push_back(producer_body);
auto new_producer_body = SeqStmt(producer_stmts);
Array<Stmt> consumer_stmts;
consumer_stmts.push_back(inc_reg_stmt);
consumer_stmts.push_back(consumer_body.value());
auto new_consumer_body = SeqStmt(consumer_stmts);
auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
new_consumer_body);
auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);
return new_attr;
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Array<IntImm> nreg_;
IterVar thread_iv_;
Optional<PrimExpr> updated_thread_extent_;
bool need_update_thread_extent_ = false;
};
using namespace tir::transform;
tvm::transform::Pass AnnotateWarpGroupRegAlloc() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) -> PrimFunc {
return SetMaxNRegInjector::Inject(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc",
AnnotateWarpGroupRegAlloc);
});
} // namespace tl
} // namespace tvm
......@@ -203,12 +203,9 @@ private:
Stmt body = Substitute(fnode->body, vmap);
return For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
} else {
return fnode;
}
} else {
return ret;
}
return ret;
}
PrimExpr VisitExpr_(const CallNode *node) final {
......
......@@ -57,7 +57,7 @@ public:
void VisitStmt_(const EvaluateNode *op) final {
Proxy proxy = Proxy::kAsync;
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(ptx_ldmatirx()) ||
if (call->op.same_as(ptx_ldmatrix()) ||
call->op.same_as(ptx_stmatrix())) {
proxy = Proxy::kGeneric;
}
......
......@@ -11,7 +11,6 @@
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "../runtime/runtime.h"
namespace tvm {
......
......@@ -10,7 +10,6 @@
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "../runtime/runtime.h"
namespace tvm {
......
......@@ -430,11 +430,6 @@ private:
return workspace.access_ptr(2); // write
};
// Get pass config `tl.disable_tma_lower`
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
Optional<Bool> opt_disable_tma_lower =
ctxt->GetConfig(kDisableTMALower, Optional<Bool>());
bool disable_tma_lower = opt_disable_tma_lower.value_or(Bool(false));
Range thread_bounds;
if (analyzer_->const_int_bound.IsBound(thread_var_->var)) {
......@@ -449,10 +444,10 @@ private:
thread_bounds = Range::FromMinExtent(0, 1);
}
auto lowered = tile_op->Lower(
LowerArgs{target_, thread_bounds, thread_var_->var, callback,
layout_map_, buffer_remap_, disable_tma_lower},
analyzer_);
auto lowered =
tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var,
callback, layout_map_, buffer_remap_},
analyzer_);
return IRMutatorWithAnalyzer::VisitStmt(lowered);
}
......
......@@ -10,7 +10,6 @@
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "../runtime/runtime.h"
namespace tvm {
......
......@@ -1146,42 +1146,6 @@ private:
bool has_simt_copy_ = false;
};
class SetMaxNRegCollector : public StmtExprVisitor {
public:
static Array<IntImm> Collect(const PrimFunc &f) {
SetMaxNRegCollector collector;
collector(f->body);
return collector.has_no_set_max_nreg_
? Array<IntImm>({IntImm(DataType::Int(32), -1),
IntImm(DataType::Int(32), -1)})
: collector.nreg_;
}
private:
void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(set_max_nreg())) {
int reg_hint = call->args[0].as<IntImmNode>()->value;
int is_inc = call->args[1].as<IntImmNode>()->value;
ICHECK(reg_hint <= 240 && reg_hint >= 24)
<< "Invalid reg hint: " << reg_hint;
ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc;
// producer should decrease register hint while consumer should increase
// register hint
nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
} else if (call->op.same_as(no_set_max_nreg())) {
has_no_set_max_nreg_ = true;
}
}
StmtExprVisitor::VisitStmt_(op);
}
Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), 0)};
bool has_no_set_max_nreg_ = false;
};
class WarpSpecializedRewriter : public StmtExprMutator {
public:
WarpSpecializedRewriter(bool disable_warp_specialized,
......@@ -1202,7 +1166,6 @@ public:
auto T = WarpSpecializedRewriter(disable_warp_specialized,
disable_shuffle_elect);
T.nreg_ = SetMaxNRegCollector::Collect(f);
T.buffer_lca_ = DetectBufferAccessLCA(f);
for (auto [buffer, _] : T.buffer_lca_)
T.buffer_data_to_buffer_.Set(buffer->data, buffer);
......@@ -1229,16 +1192,6 @@ private:
}
}
Stmt VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(set_max_nreg()) ||
call->op.same_as(no_set_max_nreg())) {
return Evaluate(0);
}
}
return StmtExprMutator::VisitStmt_(op);
}
// If users define a thread binding, we will replace the thread binding with
// threadIdx.x We require the thread binding is threadIdx.x, and the extent is
// the same as the thread extent
......@@ -1334,22 +1287,6 @@ private:
if (!marker.HasSimtCopy())
producer_thread_extent = 128;
// TODO: estimate the correct reg usage.
int dec_reg = nreg_[0].as<IntImmNode>()->value;
int inc_reg = nreg_[1].as<IntImmNode>()->value;
auto inc_reg_stmt = Evaluate(0);
auto dec_reg_stmt = Evaluate(0);
if (dec_reg >= 0 && inc_reg >= 0 && !marker.HasSimtCopy()) {
inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
{inc_reg == 0 ? 240 : inc_reg, 1}));
dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
{dec_reg == 0 ? 24 : dec_reg, 0}));
}
producer_code = SeqStmt({dec_reg_stmt, producer_code});
consumer_code = SeqStmt({inc_reg_stmt, consumer_code});
updated_thread_extent_ = consumer_thread_extent + producer_thread_extent;
producer_code = ThreadIdxRewriter::Rewrite(
......@@ -1382,7 +1319,7 @@ private:
// Add an attr here to handle the partial thread count in ThreadSync pass.
Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent),
Downcast<IntImm>(consumer_thread_extent)};
body = AttrStmt(ws_partition, "kWarpSpecializationScope", 0, body);
body = AttrStmt(ws_partition, attr::kWarpSpecializationScope, 0, body);
block.CopyOnWrite()->body = SeqStmt({init_barrier, body});
block_realize.CopyOnWrite()->block = block;
......@@ -1399,17 +1336,26 @@ private:
bool need_update_thread_extent_ = false;
bool disable_warp_specialized_ = false;
bool disable_shuffle_elect_ = false;
Array<IntImm> nreg_;
bool only_has_wgmma_ = false;
};
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public:
// Returning true means auto warp specialization will be disabled
static bool Detect(Stmt stmt, bool skip_thread_partition = false) {
WarpSpecializedDetector detector;
detector.VisitStmt(stmt);
return detector.has_warp_specialization_ ||
(detector.has_tma_op_ && detector.has_mbarrier_op_);
if (detector.has_warp_specialization_) {
LOG(WARNING) << "Auto warp specialization will be disabled because warp "
"specialization is manually enabled";
return true;
}
if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
"and mbarrier are both present";
return true;
}
return false;
}
WarpSpecializedDetector() {
......
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType, tir
import tilelang.language as T
tilelang.testing.set_random_seed(0)
def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float16"
assert val.dtype == "uint8"
# e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
# s1e2n1
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = f4 & tir.const(7, "uint16")
e_f16 = e_f4 | tir.const(8, "uint16")
val_f16 = tir.reinterpret(
"float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
return val_f16
def torch_convert(tensor):
def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
print(name, binary_repr)
def _convert(val, pos):
assert val.dtype == torch.uint8
val = val.view(torch.int8)
mask = (1 << 4) - 1
f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
s = f4 >> 3
e_f4 = f4 & 7
e_f16 = e_f4 | 8
val_f16 = ((e_f16 | (s << 5)) << 10) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.float16)
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
return main
def test_fp4_fp16_convert_close():
N, K = 256, 256
block_N, block_K = 64, 64
program = _convert_test(
N,
K,
block_N,
block_K,
"float16",
)
print(program.script())
kernel = tilelang.compile(program, out_idx=[1])
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
tl_out = kernel(B)
ref_out = torch_convert(B)
assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
print("Pass")
def matmul_fp16xfp4(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
block_M=64,
block_N=64,
block_K=64,
num_stages=1,
threads=128):
num_bits = 4
def kernel_func(block_M, block_N, block_K, num_stages, threads):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K) == 0
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
return main
return kernel_func(
block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages, threads=threads)
def ref_program(A, qB):
dtypeC = "float16"
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C.transpose(0, 1)
def assert_simple_impl_float16xfp4_gemm(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
block_M=64,
block_N=64,
block_K=64,
num_stages=1,
threads=128):
func = matmul_fp16xfp4(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K,
num_stages, threads)
torch_func = tilelang.compile(func, out_idx=[2])
profiler = torch_func.get_profiler()
profiler.assert_allclose(ref_program)
def test_simple_impl_float16xfp4_gemm():
assert_simple_impl_float16xfp4_gemm(256, 256, 256, "float16", "float16", "float32", 64, 64, 64,
1, 128)
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
num_bits=4,
):
from bitblas.quantization import _tir_packed_to_unsigned_convert
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
MAX_TRANSACTION_SIZE_IN_BITS = 128
local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
local_size_compressed = local_size // num_elems_per_byte
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_local([local_size_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([local_size], in_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
tx = T.get_thread_binding()
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = i * threads * local_size_compressed + tx * local_size_compressed + v
vi = index // (block_K // num_elems_per_byte)
vj = index % (block_K // num_elems_per_byte)
B_local[v] = B_shared[vi, vj]
for v in T.serial(0, local_size):
B_dequantize_local[v] = _tir_packed_to_unsigned_convert(
storage_type, storage_nbit)(
num_bits,
B_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
)
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
vi = index // block_K
vj = index % block_K
B_dequantize_shared[vi, vj] = B_dequantize_local[v]
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler()
out = profiler.run_once()
assert out is not None
def ref_program(A, qB):
import torch
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program)
# bitblas currently only supports sm80-sm90
@tvm.testing.requires_package("bitblas")
@tilelang.testing.requires_llvm
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
transform_b,
):
from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform,)
from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_rows = 4
warp_cols = 4
warp_row_tiles = micro_size_x * warp_rows
warp_col_tiles = micro_size_y * warp_cols
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
reduce_k = 1
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = 32 if in_dtype == "float16" else 64
chunk = block_K // reduce_k
is_smooth_a = False
can_swizzle = block_K * DataType(in_dtype).bits == 512
apply_pad_a = not (is_smooth_a or can_swizzle)
pad_factor = 8
A_shape = (M, K)
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y,
micro_size_k // num_elems_per_byte)
A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
B_shared_shape = (
block_N // micro_size_y,
block_K // micro_size_k,
micro_size_y,
micro_size_k // num_elems_per_byte,
)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitterWithLadderTransform(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
reduce_k=reduce_k,
transform_kind_b=transform_b,
num_elems_per_byte=num_elems_per_byte)
vec_load_qb = 16
if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb:
vec_load_qb = block_N * (block_K // reduce_k) // num_elems_per_byte // threads
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
prelude=decode_i4_to_f16) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size), in_dtype)
B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype)
B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
reduced_accum_res = T.alloc_local(0, accum_dtype)
thread_binding = T.get_thread_binding(0)
rk = T.get_thread_binding(1)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
})
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, (block_K // reduce_k)):
vk = rk * (block_K // reduce_k) + k
A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk]
# TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
(threads * vec_load_qb)):
for v in T.vectorized(0, vec_load_qb):
t = thread_binding
idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
vkk = idx % (micro_size_k // num_elems_per_byte)
vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (
block_K // micro_size_k)
vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y //
(block_K // micro_size_k)) % (
block_N // micro_size_y)
B_shared[vj, vk, vjj,
vkk] = B[bx * (block_N // micro_size_y) + vj,
ko * (block_K // micro_size_k) + vk, vjj, vkk]
for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
rk=rk,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
rk=rk,
)
for j in T.serial(warp_cols):
local_size_b = mma_emitter.local_size_b
T.call_extern('handle', 'decode_i4u_to_f16',
T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
T.address_of(B_dequantize_local[j * local_size_b]), 8)
mma_emitter.mma(A_local, B_dequantize_local, C_local)
if reduce_k > 1:
for n in T.serial(warp_rows * warp_cols * local_size):
T.attr(
T.comm_reducer(lambda x, y: x + y, [T.float16(0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
)
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
C_local[n],
True,
reduced_accum_res[0],
rk,
dtype="handle",
))
if rk == 0:
C_local[n] = reduced_accum_res[0]
if rk == 0:
mma_emitter.stmatrix(
C_local,
C_shared,
)
for i, j in T.Parallel(block_M, (block_N // reduce_k)):
vj = rk * (block_N // reduce_k) + j
C[by * block_M + i,
bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y,
i % micro_size_x, vj % micro_size_y]
return main
def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
transform_b,
):
import bitblas
matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
kernel = tilelang.compile(matmul, out_idx=[2])
profiler = kernel.get_profiler()
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
assert src_code is not None
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
qB = torch.randint(
0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
M=N,
N=K,
transform_kind=transform_b,
transpose_matrix=True,
dequantize_bits=num_bits,
storage_dtype=storage_dtype,
)
ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
lop3_permutate_config = bitblas.ops.LOP3PermutateConfig(
M=N,
N=K,
datatype=in_dtype,
dequantize_bits=num_bits,
storage_dtype=storage_dtype,
)
lop3_permutate = bitblas.ops.LOP3Permutate(
config=lop3_permutate_config,
target=tvm.target.Target("llvm"),
)
QLB = ladder_permutate(qB.cpu()).cuda()
QLB = lop3_permutate(QLB.cpu()).cuda()
C = kernel(A, QLB)
latency = profiler.do_bench()
# Ensure that the latency is not None
assert latency is not None
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
# Get Reference Result
ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
print("Ref C: ", ref_c)
print("C: ", C)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_package("bitblas")
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_run_dequantize_gemm():
run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)
@tilelang.testing.requires_package("bitblas")
@tilelang.testing.requires_llvm
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
256, 1024, 512, "float16", "float16", "float16", 3)
if __name__ == "__main__":
# tilelang.testing.main()
test_fp4_fp16_convert_close()
......@@ -20,7 +20,15 @@ def reshape_test(N, M, dtype):
def run_reshape(N, M, dtype):
program = reshape_test(N, M, dtype)
jit_kernel = tl.compile(program, out_idx=-1)
# TODO(lei): reshape cannot apply shared memory
# layout transform propagation
jit_kernel = tl.compile(
program,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -56,7 +64,15 @@ def reshape_test_smem_1d_2_2d(N, M, dtype):
def run_reshape_smem_1d_2_2d(N, M, dtype):
program = reshape_test_smem_1d_2_2d(N, M, dtype)
jit_kernel = tl.compile(program, out_idx=-1)
# TODO(lei): reshape cannot apply shared memory
# layout transform propagation
jit_kernel = tl.compile(
program,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = jit_kernel.get_profiler()
def ref_program(A):
......
......@@ -70,7 +70,7 @@ def matmul_sp(
backend="cutlass",
block_k=block_K),
})
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
......
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T
import tilelang.testing
from tvm import tir
tilelang.disable_cache()
def test_inject_set_max_nreg():
"""Test the InjectSetMaxNReg pass"""
@T.prim_func
def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16")):
bx = T.launch_thread("blockIdx.x", 8)
by = T.launch_thread("blockIdx.y", 8)
v = T.launch_thread("threadIdx.x", 128)
with T.block(""):
T.reads(A[by * 64, 0:512], B[0:512, bx * 64])
T.writes()
# Add set_max_nreg hints
T.annotate_producer_reg_dealloc(24) # Producer: decrease to 24
T.annotate_consumer_reg_alloc(240) # Consumer: increase to 240
A_shared = T.alloc_buffer((3, 1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.alloc_buffer((3, 1, 4, 512), "float16", scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local")
T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128)
T.attr([128, 128], "kWarpSpecializationScope", 0)
if v >= 128:
# Producer branch - should have set_max_nreg(24, 0)
for k in range(16):
T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
if v - 128 == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1,
0, 2, 2, 0), T.get_mbarrier(k % 3),
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, by * 64)
T.evaluate(
tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
else:
# Consumer branch - should have set_max_nreg(240, 1)
for k in range(16):
T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
T.call_extern(
"handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
T.evaluate(
tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
# Apply the InjectSetMaxNReg pass
func = before
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.AnnotateWarpGroupRegAlloc()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
# Check that set_max_nreg calls are properly injected
main_func = mod["main"]
set_max_nreg_calls = []
def collect_set_max_nreg(stmt):
if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
set_max_nreg_calls.append(stmt.value)
tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)
# We should have at least 2 set_max_nreg calls (one for producer, one for consumer)
assert len(set_max_nreg_calls
) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}"
# Check that we have the expected register values
reg_values = [call[0] for call in set_max_nreg_calls]
assert 24 in reg_values, f"Expected register value 24 in {reg_values}"
assert 240 in reg_values, f"Expected register value 240 in {reg_values}"
print("InjectSetMaxNReg test passed!")
def test_inject_set_max_nreg_no_set_max_nreg():
"""Test the InjectSetMaxNReg pass with no_set_max_nreg"""
@T.prim_func
def before_no_set_max_nreg(A: T.Tensor((512, 512), "float16")):
bx = T.launch_thread("blockIdx.x", 8)
v = T.launch_thread("threadIdx.x", 128)
with T.block(""):
T.reads(A[bx * 64, 0:64])
T.writes()
# Add no_set_max_nreg to disable register hinting
T.disable_warp_group_reg_alloc()
T.create_list_of_mbarrier(128, 128)
T.attr([128, 128], "kWarpSpecializationScope", 0)
if v >= 128:
# Producer branch - should not have set_max_nreg calls
T.evaluate(0)
else:
# Consumer branch - should not have set_max_nreg calls
T.evaluate(0)
# Apply the InjectSetMaxNReg pass
func = before_no_set_max_nreg
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.AnnotateWarpGroupRegAlloc()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
# Check that no set_max_nreg calls are injected when no_set_max_nreg is present
main_func = mod["main"]
set_max_nreg_calls = []
def collect_set_max_nreg(stmt):
if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
set_max_nreg_calls.append(stmt.value)
tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)
# Should have no set_max_nreg calls when no_set_max_nreg is present
assert len(
set_max_nreg_calls
) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}"
print("InjectSetMaxNReg with no_set_max_nreg test passed!")
if __name__ == "__main__":
tilelang.testing.main()
......@@ -101,6 +101,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.MultiVersionBuffer()(mod)
mod = tilelang.transform.WarpSpecialized()(mod)
mod = tilelang.transform.InjectTmaBarrier()(mod)
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
# if tma is not enabled, we can also do pipeline planning
# to get better performance with async copy
mod = tilelang.transform.PipelinePlanning()(mod)
......
......@@ -232,6 +232,9 @@ class Environment:
def disable_cache(self) -> None:
CacheState.disable()
def is_print_on_compilation_enabled(self) -> bool:
return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on")
# Instantiate as a global configuration object
env = Environment()
......
......@@ -117,7 +117,7 @@ class JITKernel(object):
# Print a log line when compilation starts
# NOTE(Chenggang): printing makes it easier for the training/inference framework to know
# whether the communication timeout is from compilation
if env.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on"):
if env.is_print_on_compilation_enabled():
# assert func must have "global_symbol"
func_name = func.attrs.get("global_symbol")
assert func_name is not None, "func must have global_symbol"
......@@ -126,6 +126,11 @@ class JITKernel(object):
# Compile the TileLang function and create a kernel adapter for execution.
adapter = self._compile_and_create_adapter(func, out_idx)
if env.is_print_on_compilation_enabled():
func_name = func.attrs.get("global_symbol")
assert func_name is not None, "func must have global_symbol"
logger.info(f"TileLang finished compiling kernel `{func_name}`")
# The adapter's function is assigned as the callable function for this instance.
self.adapter = adapter
self.torch_function = adapter.func
......
......@@ -142,12 +142,30 @@ def dec_max_nreg(reg_count: int):
return set_max_nreg(reg_count, 0)
def annotate_producer_reg_dealloc(reg_count: int = 24):
"""Annotate the producer reg dealloc.
"""
return dec_max_nreg(reg_count)
def annotate_consumer_reg_alloc(reg_count: int = 240):
"""Annotate the consumer reg alloc.
"""
return inc_max_nreg(reg_count)
def no_set_max_nreg():
"""Disable the maximum register limit setting.
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"))
def disable_warp_group_reg_alloc():
"""Disable the warp group reg alloc.
"""
return no_set_max_nreg()
def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]):
"""Wait for memory barrier parity condition.
......
......@@ -189,6 +189,21 @@ def WarpSpecialized():
return _ffi_api.WarpSpecialized() # type: ignore
def AnnotateWarpGroupRegAlloc():
"""Inject set_max_nreg calls into warp-specialized functions.
This pass analyzes the function to collect register hints from set_max_nreg
and no_set_max_nreg calls, then injects appropriate set_max_nreg calls into
producer and consumer branches of warp-specialized code.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.AnnotateWarpGroupRegAlloc() # type: ignore
def InjectTmaBarrier():
"""InjectTmaBarrier
......