Commit 280e6627 authored by Zhengju Tang, committed by LeiWang1999

[Dynamic Symbolic] Add pass_config to customize vectorization and tail split (#383)



* [Dynamic Symbolic] Add pass_config to customize vectorization and tail split

* Lint

* Only check for vectorized dimension. Add docs.

* Lint

* Update comment for cache directory in .gitignore

* Use CUTLASS convention to represent dynamic alignment. Fix bugs

* Add benchmark examples

* Add more benchmarks. Fix accumulate type bug.

* Lint

* Lint

* Test Lint

* Lint

* Test Lint

* Lint

* Fix typo

* Lint

* Lint

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 310fea95
@@ -89,3 +89,6 @@ tilelang/lib
# cython
tilelang/jit/adapter/cython/.cycache tilelang/jit/adapter/cython/.cycache
# cache directory for clangd
.cache/
import tilelang
import tilelang.language as T
import tilelang.testing
from tilelang import tvm as tvm

tilelang.testing.set_random_seed(0)
tilelang.disable_cache()


def matmul_dynamic_mnk(
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    M = tvm.te.var("m")
    N = tvm.te.var("n")
    K = tvm.te.var("k")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def test_matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
                        accum_dtype, num_stages, threads):
    print(f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, "
          f"block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, "
          f"in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, "
          f"num_stages: {num_stages}, threads: {threads}")
    program = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
                                 accum_dtype, num_stages, threads)
    kernel = tilelang.compile(
        program, pass_configs={
            "tl.disable_dynamic_tail_split": True,
            "tl.dynamic_alignment": 8
        })

    import torch
    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(getattr(torch, out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
    print("Kernel output matches PyTorch reference.")

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench(input_tensors=[A, B, C])
    print(f"Latency: {latency} ms")


if __name__ == "__main__":
    test_matmul_dynamic(16384, 16384, 16384, 128, 128, 32, False, False, "float16", "float16",
                        "float32", 3, 128)
@@ -19,6 +19,8 @@ namespace tl {
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);

#define TIR_DEFINE_TL_BUILTIN(OpName) \
  const Op &OpName() { \
......
@@ -18,6 +18,27 @@ static constexpr const char *kDisableWarpSpecialized =
    "tl.disable_warp_specialized";
static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";

/*!
 * \brief Whether to disable dynamic tail split
 *
 * kDisableDynamicTailSplit = "tl.disable_dynamic_tail_split"
 *
 */
static constexpr const char *kDisableDynamicTailSplit =
    "tl.disable_dynamic_tail_split";

/*!
 * \brief The required alignment (in elements) of the vectorized dimension of a
 * buffer, as specified by the user
 *
 * For example, if the vectorized load is 128 bits wide and the dtype of buffer
 * A[m, k] is float16, the size of the vectorized dimension (i.e. k) of buffer
 * A must be divisible by 8 (8 = 128 / 16).
 *
 * kDynamicAlignment = "tl.dynamic_alignment"
 *
 */
static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";

/*!
 * \brief tvm intrinsics for TMADescriptor creation for tiled load
 *
......
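In user code, both options are supplied through pass_configs at compile time. The snippet below mirrors the tests added in this commit; program stands for any @T.prim_func built as in the example above:

import tilelang

# Disable dynamic tail split and require every dynamic buffer's innermost
# (vectorized) dimension to be divisible by 8 elements.
kernel = tilelang.compile(
    program,
    pass_configs={
        "tl.disable_dynamic_tail_split": True,
        "tl.dynamic_alignment": 8,
    })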
@@ -4,6 +4,7 @@
 * \brief Reference to loop_vectorize.cc and vectorize_loop.cc
 */
#include <cstdint>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
@@ -13,6 +14,7 @@
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/builtin.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"
@@ -60,7 +62,17 @@ bool IndiceCanVectorizeDynamic(PrimExpr expr, Var var, PrimExpr iter_var_size,
class VectorizePlannerDynamic : public arith::IRVisitorWithAnalyzer {
public:
  VectorizePlannerDynamic(int dynamic_alignment,
                          bool disable_dynamic_tail_split)
      : dynamic_alignment_(dynamic_alignment),
        disable_dynamic_tail_split_(disable_dynamic_tail_split),
        vector_load_bits_max_(128) {
    if (disable_dynamic_tail_split_) {
      vector_size_ = dynamic_alignment_;
    } else {
      vector_size_ = vector_load_bits_max_;
    }
  }
  int Plan(const For &node) {
    this->operator()(node);
@@ -167,21 +179,29 @@ private:
                                     &analyzer_)) {
        vector_size_ /= 2;
      }
    } else {
      // dynamic shape load: get the vectorization condition
      dynamic_ = true;
      if (!disable_dynamic_tail_split_ &&
          vector_size_ >= vector_load_bits_max_ / buffer->dtype.bits()) {
        vector_size_ = vector_load_bits_max_ / buffer->dtype.bits();
      }
      PrimExpr offset = buffer.OffsetOf(indices).back();
      // condition for alignment, maybe useless
      condition_ = (FloorMod(offset, vector_size_) == 0);
    }
  }
  // Use dynamic alignment from pass config
  int vector_load_bits_max_;
  int dynamic_alignment_;
  bool disable_dynamic_tail_split_;
  int vector_size_;
  const ForNode *inner_for_;
  Map<Var, Range> iter_map_;
  bool has_nonlocal_memory_access_ = false;
  // conditionally vectorize
  bool dynamic_ = false;
  PrimExpr condition_;
@@ -324,12 +344,21 @@ private:
class VectorizeRewriterDynamic : public StmtExprMutator {
public:
  VectorizeRewriterDynamic(VectorizePlanResult plan,
                           bool disable_dynamic_tail_split)
      : vector_size_(plan.vector_size), condition_(plan.condition),
        dynamic_(plan.dynamic),
        disable_dynamic_tail_split_(disable_dynamic_tail_split) {}
private:
  Stmt VisitStmt_(const ForNode *node) final {
    // Get pass config `tl.disable_dynamic_tail_split`
    tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
    Optional<Bool> opt_disable_dynamic_tail_split =
        ctxt->GetConfig(kDisableDynamicTailSplit, Optional<Bool>());
    bool disable_dynamic_tail_split =
        opt_disable_dynamic_tail_split.value_or(Bool(false));
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
    if (inner_for_ != node) {
@@ -365,28 +394,47 @@ private:
      condition_bound = condition_bound && condition_mutator(conditions[i]);
    }
    if (!disable_dynamic_tail_split) {
      // If dynamic tail split is enabled, vectorize the loop under an
      // if-then-else condition: modify the body in the vectorized branch
      VectorizedBodyMutator mutator(inner_var, vector_size_, conditions);
      Stmt vectorize_body = mutator(body);

      // add condition ifthenelse here
      For vectorize_for =
          For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body);
      For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body);
      body = IfThenElse(condition_bound, vectorize_for, serial_for);
      body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                 fnode->thread_binding, fnode->annotations, fnode->span);
      return body;
    } else {
      // If dynamic tail split is disabled, vectorize the loop directly,
      // without tail split or if_then_else; this is only safe when the
      // buffers satisfy the configured alignment
      VectorizedBodyMutator mutator(inner_var, vector_size_, conditions);
      Stmt vectorize_body = mutator(body);

      For vectorize_for =
          For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body);
      body =
          For(outer_var, 0, extent / vector_size_, fnode->kind, vectorize_for,
              fnode->thread_binding, fnode->annotations, fnode->span);
      return body;
    }
  }
  const ForNode *inner_for_;
  const int vector_size_;
  const PrimExpr condition_;
  const bool dynamic_;
  const bool disable_dynamic_tail_split_;
};
VectorizePlanResult
GetVectorizePlanResultDynamic(const For &loop, int dynamic_alignment,
                              bool disable_dynamic_tail_split) {
  VectorizePlannerDynamic planner(dynamic_alignment,
                                  disable_dynamic_tail_split);
  int vector_size = planner.Plan(loop);
  bool dynamic = planner.GetDynamic();
  PrimExpr condition = planner.GetCondition();
@@ -395,30 +443,40 @@ VectorizePlanResult GetVectorizePlanResultDynamic(const For &loop) {
class LoopVectorizerDynamic : public IRMutatorWithAnalyzer {
public:
  static Stmt Substitute(Stmt stmt, bool disable_dynamic_tail_split,
                         int dynamic_alignment) {
    arith::Analyzer analyzer;
    LoopVectorizerDynamic substituter(&analyzer, disable_dynamic_tail_split,
                                      dynamic_alignment);
    stmt = substituter.VisitStmt(stmt);
    return stmt;
  }

private:
  LoopVectorizerDynamic(arith::Analyzer *analyzer,
                        bool disable_dynamic_tail_split, int dynamic_alignment)
      : arith::IRMutatorWithAnalyzer(analyzer),
        disable_dynamic_tail_split_(disable_dynamic_tail_split),
        dynamic_alignment_(dynamic_alignment) {}

  Stmt VisitStmt_(const ForNode *op) final {
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    VectorizePlanResult res{vector_load_bits_max_, false, 0};
    res = GetVectorizePlanResultDynamic(for_node, dynamic_alignment_,
                                        disable_dynamic_tail_split_);
    NestedLoopChecker checker;
    int nest_num = checker.GetNestLoopNum(for_node);
    if (nest_num > 1) { // only rewrite the innermost loop
      return for_node;
    }
    int vectorize_hint = res.vector_size;
    auto rewriter = VectorizeRewriterDynamic(res, disable_dynamic_tail_split_);
    return Downcast<For>(rewriter(for_node));
  }

  const int vector_load_bits_max_ = 128;
  int dynamic_alignment_;
  bool disable_dynamic_tail_split_;
};
class VectorizeSkipperDynamic : public StmtMutator {
@@ -437,8 +495,21 @@ public:
tvm::transform::Pass LoopVectorizeDynamic() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    bool disable_dynamic_tail_split =
        ctx->GetConfig<Bool>(kDisableDynamicTailSplit, Bool(true)).value();
    int dynamic_alignment =
        (int)(ctx->GetConfig<Integer>(kDynamicAlignment, Integer(8))
                  .value_or(Integer(8))
                  ->value);
    // Ensure tl.dynamic_alignment is a power of 2
    if (disable_dynamic_tail_split &&
        ((dynamic_alignment & (dynamic_alignment - 1)) != 0)) {
      LOG(FATAL) << "tl.dynamic_alignment must be a power of 2, but got "
                 << dynamic_alignment;
    }
    auto *n = f.CopyOnWrite();
    n->body = LoopVectorizerDynamic::Substitute(
        std::move(n->body), disable_dynamic_tail_split, dynamic_alignment);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LoopVectorizeDynamic", {});
......
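Conceptually, the rewriter now emits one of two loop shapes. The following pure-Python sketch is only a schematic model of the generated control flow, not the actual TIR:

def copy_with_tail_split(src, dst, vec):
    # tl.disable_dynamic_tail_split = False: each chunk is guarded by an
    # if-then-else; chunks that satisfy the condition take the vectorized
    # branch, the rest fall back to a serial loop.
    n = len(src)
    for outer in range(0, n, vec):
        if outer + vec <= n:  # stands in for the alignment condition
            dst[outer:outer + vec] = src[outer:outer + vec]  # "vectorized"
        else:
            for i in range(outer, n):  # serial fallback for the tail
                dst[i] = src[i]

def copy_without_tail_split(src, dst, vec):
    # tl.disable_dynamic_tail_split = True: the loop is vectorized
    # unconditionally; this is only valid when n % vec == 0, which the
    # host-side shape checks added in MakePackedAPI (below) enforce.
    n = len(src)
    for outer in range(0, n, vec):
        dst[outer:outer + vec] = src[outer:outer + vec]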
@@ -34,6 +34,7 @@
#include <utility>
#include <vector>

#include "../op/builtin.h"
#include "tir/transforms/arg_binder.h"
#include "tir/transforms/ir_utils.h"
@@ -273,6 +274,10 @@ PrimFunc MakePackedAPI(PrimFunc func) {
  std::vector<Stmt> seq_init, seq_check, arg_buffer_declarations;
  std::unordered_map<const VarNode *, PrimExpr> vmap;
  ArgBinder binder(&vmap);
  std::vector<Stmt> shape_checks;
  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
  bool disable_dynamic_tail_split =
      ctxt->GetConfig<Bool>(kDisableDynamicTailSplit, Bool(true)).value();

  // ---------------------------
  // local function definitions
@@ -416,12 +421,44 @@ PrimFunc MakePackedAPI(PrimFunc func) {
    }
  }
  // (zhengju) For dynamic constraint, we need to check the buffer shape and
  // dtype to make sure the buffer can be vectorized.
  for (const auto &kv : buffer_def) {
    if (disable_dynamic_tail_split) {
      Optional<Integer> opt_dynamic_alignment =
          ctxt->GetConfig(kDynamicAlignment, Optional<Integer>());
      int dynamic_alignment = opt_dynamic_alignment.value_or(Integer(8))->value;
      // The vectorized dimension is the last dimension of the buffer
      auto vectorize_dim = kv.second->shape[kv.second->shape.size() - 1];
      auto shape_vectorize_expr = [&]() -> PrimExpr {
        PrimExpr result = IntImm(kv.second->DefaultIndexType(), 1);
        result = result * vectorize_dim;
        result = FloorMod(result, dynamic_alignment);
        return result;
      }();
      shape_checks.emplace_back(AssertStmt(
          shape_vectorize_expr == 0,
          tvm::tir::StringImm(
              kv.second->name +
              ": Vectorize dimension in buffer must be divisible by " +
              std::to_string(dynamic_alignment)),
          nop));
    }
  }
  // Return error code of zero on success
  body = SeqStmt({body, Evaluate(ret(Integer(0)))});

  if (!disable_dynamic_tail_split) {
    body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(),
                      arg_buffer_declarations},
                     body);
  } else {
    body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(),
                      arg_buffer_declarations, shape_checks},
                     body);
  }

  func_ptr->body = body;
  func_ptr->params = args;
......
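Each generated AssertStmt amounts to a host-side divisibility check on the last (vectorized) dimension of every buffer argument. In rough Python terms (a hypothetical mirror of the emitted checks, not generated code):

def check_buffer_shapes(shapes, dynamic_alignment=8):
    # shapes: {buffer name: shape tuple}, checked before kernel launch.
    for name, shape in shapes.items():
        if shape[-1] % dynamic_alignment != 0:
            raise AssertionError(f"{name}: Vectorize dimension in buffer "
                                 f"must be divisible by {dynamic_alignment}")

check_buffer_shapes({"A": (1024, 512), "B": (512, 256)})  # passes
# check_buffer_shapes({"A": (1024, 60)})  # would raise for alignment 8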
@@ -387,6 +387,71 @@ def assert_tl_matmul_block_all_dynamic_correctness(
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
    dynamic_alignment=8,
):
    program = tl_matmul_block_all_dynamic(
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(
        program,
        pass_configs={
            "tl.disable_dynamic_tail_split": dynamic_alignment != 0,
            "tl.dynamic_alignment": dynamic_alignment
        })

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(getattr(torch, out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
def test_assert_tl_matmul_macro():
    assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
    assert_tl_matmul_macro_correctness(66, 128, 128, "float16", "float16", "float16")
@@ -411,6 +476,39 @@ def test_assert_tl_matmul_block_all_dynamic():
                                                "float16", 64, 64, 32)
def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        128,
        128,
        128,
        False,
        False,
        "float16",
        "float16",
        "float16",
        64,
        64,
        32,
        dynamic_alignment=8)
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        64,
        128,
        128,
        False,
        False,
        "float16",
        "float16",
        "float16",
        64,
        64,
        32,
        dynamic_alignment=8)
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4)
    # Tail split is enabled with dynamic alignment 0
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0)
if __name__ == "__main__":
    tilelang.testing.main()
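Note the convention used in the test above: a nonzero dynamic_alignment disables tail split, while 0 keeps dynamic tail split enabled. A hypothetical helper making that mapping explicit:

def make_pass_configs(dynamic_alignment):
    # alignment > 0: disable tail split and enforce the alignment;
    # alignment == 0: keep dynamic tail split enabled.
    cfg = {"tl.disable_dynamic_tail_split": dynamic_alignment != 0}
    if dynamic_alignment:
        cfg["tl.dynamic_alignment"] = dynamic_alignment
    return cfg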
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T

tilelang.testing.set_random_seed(0)
tilelang.disable_cache()


def tl_matmul_block_static(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    num_threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
             C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_static(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages=3,
    num_threads=128,
):
    program = tl_matmul_block_static(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        accum_dtype,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(program)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)
    # print(kernel.get_kernel_source())

    def ref_program(A, B):
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(getattr(torch, out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench()
    print(f"Static Latency: {latency} ms")


def tl_matmul_block_dynamic_m(
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    num_threads,
):
    M = tvm.te.var("m")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
             C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_dynamic_m(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    num_stages=3,
    num_threads=128,
    pass_configs=None,
):
    program = tl_matmul_block_dynamic_m(
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(program, pass_configs=pass_configs)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(getattr(torch, out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench(input_tensors=[A, B, C])
    print(f"Dynamic M Latency with pass_configs: {pass_configs} is {latency} ms")


def tl_matmul_block_dynamic_mn(
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    num_threads,
):
    M = tvm.te.var("m")
    N = tvm.te.var("n")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
             C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_dynamic_mn(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    num_stages=3,
    num_threads=128,
    pass_configs=None,
):
    program = tl_matmul_block_dynamic_mn(
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(program, pass_configs=pass_configs)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(getattr(torch, out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench(input_tensors=[A, B, C])
    print(f"Dynamic MN Latency with pass_configs: {pass_configs} is {latency} ms")


def tl_matmul_block_dynamic_mnk(
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    num_threads,
):
    M = tvm.te.var("m")
    N = tvm.te.var("n")
    K = tvm.te.var("k")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
             C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_dynamic_mnk(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    num_stages=3,
    num_threads=128,
    pass_configs=None,
):
    program = tl_matmul_block_dynamic_mnk(
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(program, pass_configs=pass_configs)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(getattr(torch, out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench(input_tensors=[A, B, C])
    print(f"Dynamic MNK Latency with pass_configs: {pass_configs} is {latency} ms")


def test_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K):
    assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16",
                                  "float16", "float32")


def test_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
    assert_tl_matmul_block_dynamic_m(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        "float16",
        "float16",
        "float32",
        pass_configs={
            "tl.disable_dynamic_tail_split": True,
            "tl.dynamic_alignment": 8
        })
    assert_tl_matmul_block_dynamic_m(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        "float16",
        "float16",
        "float32",
        pass_configs={"tl.disable_dynamic_tail_split": False})


def test_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
    assert_tl_matmul_block_dynamic_mn(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        "float16",
        "float16",
        "float32",
        pass_configs={
            "tl.disable_dynamic_tail_split": True,
            "tl.dynamic_alignment": 8
        })
    assert_tl_matmul_block_dynamic_mn(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        "float16",
        "float16",
        "float32",
        pass_configs={"tl.disable_dynamic_tail_split": False})


def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
    assert_tl_matmul_block_dynamic_mnk(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        "float16",
        "float16",
        "float32",
        pass_configs={
            "tl.disable_dynamic_tail_split": True,
            "tl.dynamic_alignment": 8
        })
    assert_tl_matmul_block_dynamic_mnk(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        "float16",
        "float16",
        "float32",
        pass_configs={"tl.disable_dynamic_tail_split": False})


def assert_all():
    test_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32)
    test_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32)
    test_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32)
    test_assert_tl_matmul_block_dynamic_mnk(16384, 16384, 16384, 128, 128, 32)


if __name__ == "__main__":
    assert_all()
@@ -139,6 +139,8 @@ def compile(
            "tl.disable_tma_lower": bool, default: False
            "tl.disable_warp_specialized": bool, default: False
            "tl.config_index_bitwidth": int, default: None
            "tl.disable_dynamic_tail_split": bool, default: False
            "tl.dynamic_alignment": int, default: 8
    """
    return cached(
        func=func,
......
@@ -66,6 +66,8 @@ class JITKernel(object):
        Available options:
            "tir.disable_vectorize": bool, default: False
            "tl.disable_tma_lower": bool, default: False
            "tl.disable_dynamic_tail_split": bool, default: False
            "tl.dynamic_alignment": int, default: 8
    from_database : bool, optional
        Whether to create a TorchFunction from a database.
    """
......