Commit 6972aed7 authored by Lei Wang, committed by LeiWang1999

[Language] Support explicit programming for identified warp groups (#445)

* [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.

* [Refactor] Rename operations for consistency in lower_hopper_intrin and related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.

* [Refactor] Rename operations to snake_case for consistency

* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.
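For example, a call site that used the old CamelCase spelling now reads as follows (a minimal sketch; both spellings are taken from the hunks later in this diff):

# Old (removed in this commit)  ->  new snake_case form
#   T.NoSetMaxNReg()            ->  T.no_set_max_nreg()
#   T.WaitWgmma(1)              ->  T.wait_wgmma(1)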

* [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.

* [Refactor] Clean up code formatting and improve readability

* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.

* lint fix

* [Refactor] Update mbarrier functions for improved clarity and consistency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.
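Concretely, the explicit-parameter style looks like this (a minimal sketch mirroring the `gemm_ws.py` example further down; the first argument is the barrier id, the second the expected parity/phase bit):

T.mbarrier_wait_parity(0, ko)  # block until barrier 0 reaches the phase with parity `ko`
T.mbarrier_arrive(1)           # signal one arrival on barrier 1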

* Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` for improved code readability and separation of logical sections.
parent 0fa03398
@@ -51,7 +51,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
        T.fill(acc_o, 0)
        T.fill(logsum, 0)
        T.fill(scores_max, -T.infinity(accum_dtype))
-       T.NoSetMaxNReg()
+       T.no_set_max_nreg()
        loop_range = T.ceildiv(seqlen_kv, block_N)
        for k in T.Pipelined(loop_range, num_stages=2):
            T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], qKV_shared)
...
@@ -203,7 +203,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
                   transpose_B=True,
                   policy=T.GemmWarpPolicy.FullRow,
                   wg_wait=-1)
-           T.WaitWgmma(1)
+           T.wait_wgmma(1)
            T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
            for i, j in T.Parallel(block_M, block_N):
@@ -212,7 +212,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
            for i, j in T.Parallel(block_M, block_N):
                qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                           0)
-           T.WaitWgmma(0)
+           T.wait_wgmma(0)
            T.copy(qkT, qkT_cast)
            T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
@@ -225,7 +225,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
            T.copy(dsT_cast, dsT_shared)
            T.clear(dq)
            T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
-           T.WaitWgmma(0)
+           T.wait_wgmma(0)
            for i, j in T.Parallel(block_N, dim):
                if k * block_N + i < seq_len:
                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
...
import tilelang
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    # Add the decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize kernel context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Create mbarriers for TMA: two barriers, each expecting 128 thread arrivals
            T.create_list_of_mbarrier(128, 128)

            # Warp group 1: the producer, copying tiles into shared memory
            with T.ws(1):
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(1, ko ^ 1)
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                    T.mbarrier_arrive(0)

            # Warp group 0: the consumer, running the gemm
            with T.ws(0):
                T.clear(C_local)
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(0, ko)
                    T.gemm(A_shared, B_shared, C_local)
                    T.mbarrier_arrive(1)
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(128, 128, 64, 128, 128, 32)

# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list;
# if out_idx is specified, the tensor will be created at runtime.
# target currently can be "cuda", "hip", or "cpu".
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    execution_backend="cython",
    pass_configs={
        "tl.disable_warp_specialized": True,
        "tl.disable_tma_lower": True,
    })
print(jit_kernel.get_kernel_source())

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 128, device="cuda", dtype=torch.float16)

# Run the kernel
c = jit_kernel(a, b)
print(c)

# Reference multiplication using PyTorch
ref_c = a @ b

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 5. Profile kernel latency
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
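A note on the `ko ^ 1` in the producer branch above: the two warp groups hand tiles back and forth through a pair of mbarriers whose phase bit flips each round, so each side waits on the parity matching its own iteration. A hedged reading of the handshake:

# Barrier 0 signals "tile loaded":  producer arrives, consumer waits with parity ko.
# Barrier 1 signals "buffers free": consumer arrives, producer waits with parity ko ^ 1.
# Flipping the expected parity every iteration keeps round k's waits from being
# satisfied by round k-1's arrivals on the same physical barrier.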
@@ -220,5 +220,55 @@ TVM_REGISTER_GLOBAL("tl.Parallel").set_body_typed(ParallelFor);
TVM_REGISTER_GLOBAL("tl.Pipelined").set_body_typed(PipelinedFor);
TVM_REGISTER_GLOBAL("tl.KernelLaunch").set_body_typed(KernelLaunch);

+class WarpSpecializeFrameNode : public TIRFrameNode {
+public:
+  Array<TIRFrame> frames;
+
+  void VisitAttrs(tvm::AttrVisitor *v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("frames", &frames);
+  }
+
+  static constexpr const char *_type_key = "tl.WarpSpecializeFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(WarpSpecializeFrameNode, TIRFrameNode);
+
+public:
+  TVM_DLL void EnterWithScope() final {
+    for (auto frame = frames.begin(); frame != frames.end(); ++frame)
+      (*frame)->EnterWithScope();
+  }
+  /*!
+   * \brief The method called when exiting RAII scope.
+   * \sa tvm::support::With
+   */
+  TVM_DLL void ExitWithScope() final {
+    for (auto frame = frames.rbegin(); frame != frames.rend(); ++frame)
+      (*frame)->ExitWithScope();
+  }
+};
+
+class WarpSpecializeFrame : public TIRFrame {
+public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WarpSpecializeFrame,
+                                                    TIRFrame,
+                                                    WarpSpecializeFrameNode);
+};
+
+WarpSpecializeFrame WarpSpecialize(int warp_group_idx, PrimExpr thread_idx,
+                                   int warp_group_size = 128) {
+  ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>();
+  PrimExpr min_bound =
+      max(0, IntImm(thread_idx.dtype(), warp_group_idx) * warp_group_size);
+  PrimExpr max_bound = min_bound + warp_group_size;
+  PrimExpr condition = thread_idx >= min_bound && thread_idx < max_bound;
+  IfFrame if_frame = If(condition);
+  n->frames.push_back(if_frame);
+  n->frames.push_back(Then());
+  return WarpSpecializeFrame(n);
+}
+
+TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode);
+TVM_REGISTER_GLOBAL("tl.WarpSpecialize").set_body_typed(WarpSpecialize);

} // namespace tl
} // namespace tvm
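Reading the `WarpSpecialize` helper above: `T.ws(k)` pushes an if-frame guarding its body by the flattened thread index, so with the default `warp_group_size = 128` the two branches in `gemm_ws.py` behave roughly like the following sketch (`tx` is illustrative, not an API name):

# if 128 <= tx < 256:   body of T.ws(1)  (warp group 1: the TMA/copy producer)
# if   0 <= tx < 128:   body of T.ws(0)  (warp group 0: the gemm consumer)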
@@ -30,92 +30,92 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
  TVM_REGISTER_OP("tl." #OpName)                                               \
      .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)

-TIR_DEFINE_TL_BUILTIN(CreateListofMBarrierOp)
+TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(CreateTMADescriptorOp)
+TIR_DEFINE_TL_BUILTIN(create_tma_descriptor)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

-TIR_DEFINE_TL_BUILTIN(CreateTMAIm2ColDescriptorOp)
+TIR_DEFINE_TL_BUILTIN(create_tma_im2col_descriptor)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

-TIR_DEFINE_TL_BUILTIN(GetMBarrierOp)
+TIR_DEFINE_TL_BUILTIN(get_mbarrier)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

-TIR_DEFINE_TL_BUILTIN(TMALoadOp).set_num_inputs(-1).set_attr<TCallEffectKind>(
+TIR_DEFINE_TL_BUILTIN(tma_load).set_num_inputs(-1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(TMALoadIm2ColOp)
+TIR_DEFINE_TL_BUILTIN(tma_load_im2col)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(TMAStoreOp)
-    .set_num_inputs(-1)
-    .set_attr<TCallEffectKind>("TCallEffectKind",
-                               Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_TL_BUILTIN(tma_store).set_num_inputs(-1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(MBarrierWaitParity)
+TIR_DEFINE_TL_BUILTIN(mbarrier_wait_parity)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(MBarrierExpectTX)
+TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(LDMatrixOp)
+TIR_DEFINE_TL_BUILTIN(ptx_ldmatirx)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(STMatrixOp)
+TIR_DEFINE_TL_BUILTIN(ptx_stmatirx)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(SyncThreadsPartialOp)
+TIR_DEFINE_TL_BUILTIN(sync_thread_partial)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(FenceProxyAsyncOp)
+TIR_DEFINE_TL_BUILTIN(fence_proxy_async)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(TMAStoreArrive)
+TIR_DEFINE_TL_BUILTIN(tma_store_arrive)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(TMAStoreWait)
+TIR_DEFINE_TL_BUILTIN(tma_store_wait)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(SetMaxNReg)
+TIR_DEFINE_TL_BUILTIN(set_max_nreg)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(NoSetMaxNReg)
+TIR_DEFINE_TL_BUILTIN(no_set_max_nreg)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(WaitWgmma).set_num_inputs(1).set_attr<TCallEffectKind>(
-    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_TL_BUILTIN(wait_wgmma)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(PackB16Op).set_num_inputs(2).set_attr<TCallEffectKind>(
+TIR_DEFINE_TL_BUILTIN(pack_b16).set_num_inputs(2).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kPure));

} // namespace tl
} // namespace tvm
\ No newline at end of file
@@ -42,31 +42,31 @@ static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
/*!
 * \brief tvm intrinsics for TMADescriptor creation for tiled load
 *
- * CuTensorMap* CreateTMADescriptorOp(data_type, rank, global_addr,
+ * CuTensorMap* create_tma_descriptor(data_type, rank, global_addr,
 * global_shape..., global_stride..., smem_box..., smem_stride..., interleave,
 * swizzle, l2_promotion, oob_fill)
 *
 */
-const Op &CreateTMADescriptorOp();
+const Op &create_tma_descriptor();

/*!
 * \brief tvm intrinsics for TMADescriptor creation for image to column load
 *
- * CuTensorMap* CreateTMAIm2ColDescriptorOp(data_type, rank, global_addr,
+ * CuTensorMap* create_tma_im2col_descriptor(data_type, rank, global_addr,
 * global_shape..., global_stride..., elem_stride..., lower_corner...,
 * upper_corner..., smme_box_pixel, smem_box_channel, interleave, swizzle,
 * l2_promotion, oob_fill)
 *
 */
-const Op &CreateTMAIm2ColDescriptorOp();
+const Op &create_tma_im2col_descriptor();

/*!
 * \brief Create a list of mbarrier with num_threads
 *
- * CreateListofMBarrierOp(num_threads0, num_threads1, ...)
+ * create_list_of_mbarrier(num_threads0, num_threads1, ...)
 *
 */
-const Op &CreateListofMBarrierOp();
+const Op &create_list_of_mbarrier();

/*!
 * \brief Get the mbarrier with barrier_id
@@ -74,83 +74,83 @@ const Op &CreateListofMBarrierOp();
 * int64_t* GetMBarrier(barrier_id)
 *
 */
-const Op &GetMBarrierOp();
+const Op &get_mbarrier();

/*!
 * \brief tvm intrinsics for loading data from global tensor descriptor to
 * shared memory
 *
- * TMALoadOp(descriptor, mbarrier, smem_data, coord_0, coord_1, ...)
+ * tma_load(descriptor, mbarrier, smem_data, coord_0, coord_1, ...)
 *
 */
-const Op &TMALoadOp();
+const Op &tma_load();

/*!
 * \brief tvm intrinsics for loading image from global tensor to columns in
 * shared memory
 *
- * TMALoadOp(descriptor, mbarrier, smem_data, coord_0, coord_1, ...,
+ * tma_load(descriptor, mbarrier, smem_data, coord_0, coord_1, ...,
 * image_offset, ...)
 *
 */
-const Op &TMALoadIm2ColOp();
+const Op &tma_load_im2col();

/*!
 * \brief tvm intrinsics for storing data from shared memory to global tensor
 * descriptor
 *
- * TMAStoreOp(descriptor, smem_data, coord_0, coord_1, ...)
+ * tma_store(descriptor, smem_data, coord_0, coord_1, ...)
 *
 */
-const Op &TMAStoreOp();
+const Op &tma_store();

/*!
 * \brief tvm intrinsics for mbarrier wait with parity bit
 *
- * MBarrierWaitParity(mbarrier, parity)
+ * mbarrier_wait_parity(mbarrier, parity)
 *
 */
-const Op &MBarrierWaitParity();
+const Op &mbarrier_wait_parity();

/*!
 * \brief tvm intrinsics for mbarrier expect tx
 *
- * MBarrierExpectTX(mbarrier, transaction_bytes)
+ * mbarrier_expect_tx(mbarrier, transaction_bytes)
 *
 */
-const Op &MBarrierExpectTX();
+const Op &mbarrier_expect_tx();

/*!
 * \brief tvm intrinsics for ldmatrix
 *
- * LDMatrixOp(transposed, num, shared_addr, local_addr)
+ * ptx_ldmatirx(transposed, num, shared_addr, local_addr)
 *
 */
-const Op &LDMatrixOp();
+const Op &ptx_ldmatirx();

/*!
 * \brief tvm intrinsics for stmatrix
 *
- * LDMatrixOp(transposed, num, shared_addr, int32_values...)
+ * ptx_ldmatirx(transposed, num, shared_addr, int32_values...)
 *
 */
-const Op &STMatrixOp();
+const Op &ptx_stmatirx();

/*!
 * \brief Pack two b16 value into a b32 value
 *
- * int32 PackB16Op(b16_value, b16_value)
+ * int32 pack_b16(b16_value, b16_value)
 *
 */
-const Op &PackB16Op();
+const Op &pack_b16();

/*!
 * \brief Similar to __syncthreads(), but can be used to sync partial threads
 *
- * SyncThreadsPartialOp(num_partial_threads or mbarrier)
+ * sync_thread_partial(num_partial_threads or mbarrier)
 *
 */
-const Op &SyncThreadsPartialOp();
+const Op &sync_thread_partial();

/*!
 * \brief Issue a shared memory fence for async operations
@@ -158,23 +158,23 @@ const Op &SyncThreadsPartialOp();
 * FenceProxyAsync()
 *
 */
-const Op &FenceProxyAsyncOp();
+const Op &fence_proxy_async();

/*!
 * \brief Indicate arrival of warp issuing TMA_STORE
 *
- * TMAStoreArrive()
+ * tma_store_arrive()
 *
 */
-const Op &TMAStoreArrive();
+const Op &tma_store_arrive();

/*!
 * \brief Wait for TMA_STORE to finish
 *
- * TMAStoreWait()
+ * tma_store_wait()
 *
 */
-const Op &TMAStoreWait();
+const Op &tma_store_wait();

/*!
 * \brief Set reg hint for warp-specialized branched
@@ -182,23 +182,23 @@ const Op &TMAStoreWait();
 * SetMaxNRegInc(num_reg, is_inc)
 *
 */
-const Op &SetMaxNReg();
+const Op &set_max_nreg();

/*!
 * \brief No set reg hint for warp-specialized branched
 *
- * NoSetMaxNReg()
+ * no_set_max_nreg()
 *
 */
-const Op &NoSetMaxNReg();
+const Op &no_set_max_nreg();

/*!
 * \brief Wait the previous wgmma to finish
 *
- * WaitWgmma(num_mma)
+ * wait_wgmma(num_mma)
 *
 */
-const Op &WaitWgmma();
+const Op &wait_wgmma();

/*!
 * \brief tvm intrinsic for amd matrix core mfma instructions.
...
@@ -207,14 +207,14 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
    desc.smem_box.Set(0, PrimExpr(instruction_dim));

  Call create_descriptor =
-      Call(DataType::Handle(), CreateTMADescriptorOp(), desc.EncodeCallArgs());
+      Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs());

  Array<PrimExpr> args;
  args.reserve(desc.rank + 3);
  args.push_back(create_descriptor);
  if (is_load)
    args.push_back(0); // mbarrier id placeholder
-  auto op = is_load ? TMALoadOp() : TMAStoreOp();
+  auto op = is_load ? tma_load() : tma_store();

  Stmt tma_copy;
@@ -343,7 +343,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
    }
  }

-  Call create_desc = Call(DataType::Handle(), CreateTMAIm2ColDescriptorOp(),
+  Call create_desc = Call(DataType::Handle(), create_tma_im2col_descriptor(),
                          desc.EncodeCallArgs());

  Array<PrimExpr> global_coords; // c, w, h, n
@@ -394,7 +394,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
  Stmt tma_copy =
      IfThenElse(EQ(T.thread_var, 0),
-                 Evaluate(Call(DataType::Handle(), TMALoadIm2ColOp(), args)));
+                 Evaluate(Call(DataType::Handle(), tma_load_im2col(), args)));
  return tma_copy;
}
...
@@ -166,9 +166,10 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
      par_op->InferLayout(
          {T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
    }
-    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
-                                     par_op->GetLoopLayout());
+    auto loop_layout = par_op->GetLoopLayout();
+    auto thread_var = T.thread_var;
+    auto thread_loop =
+        PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
    vectorized_thread_loop = VectorizeLoop(thread_loop);
  }
@@ -275,7 +276,7 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
    num = 2;

  Array<PrimExpr> args;
-  const Op &op = is_ldmatrix ? tl::LDMatrixOp() : tl::STMatrixOp();
+  const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatirx();
  args.push_back(static_cast<int>(is_transposed));
  args.push_back(num);
@@ -324,7 +325,7 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
        value1 = Cast(shared_tensor->dtype, value1);
      }
      PrimExpr value_packed =
-          Call(DataType::Int(32), PackB16Op(), {value0, value1});
+          Call(DataType::Int(32), pack_b16(), {value0, value1});
      args.push_back(value_packed);
    }
  }
...
@@ -161,8 +161,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
-  auto block_size = *as_const_int(T.thread_bounds->extent) -
-                    *as_const_int(T.thread_bounds->min);
+  auto block_size = *as_const_int(T.thread_bounds->extent);
  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
...
@@ -146,7 +146,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (level == InferLevel::kStrict)
    return {};
-  auto block_size = T.thread_bounds->extent - T.thread_bounds->min;
+  auto block_size = T.thread_bounds->extent;
  // Step 1: try to infer loop's partition from a source fragment
  Buffer source_buffer, read_source_buffer;
  for (const auto &[buffer, indices] : indice_map_) {
@@ -228,7 +228,8 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
    }
    PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
    if (!analyzer_.CanProveEqual(loop_thread_extent, block_size))
-      AddPredicate(LT(InputPlaceholder(0), loop_thread_extent));
+      AddPredicate(
+          LT(InputPlaceholder(0) - T.thread_bounds->min, loop_thread_extent));
  } else {
    return {};
  }
...
@@ -9,6 +9,7 @@
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
+#include <tvm/tir/stmt_functor.h>

#include "../layout/utils.h"
#include "../transform/loop_partition.h"
@@ -307,7 +308,7 @@ Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
      this->src.scope() == "shared") {
    ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
    std::stringstream ss;
-    auto threads = T.thread_bounds->extent - T.thread_bounds->min;
+    auto threads = T.thread_bounds->extent;
    ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
       << (reverse ? "true" : "false") << ">::run";
    Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
...
@@ -807,7 +807,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
    std::string barrier_name = "_mbarrier";
    this->stream << "__shared__ uint64_t " << barrier_name << "["
                 << barrier_count << "];\n";
-  } else if (op->op.same_as(tl::GetMBarrierOp())) {
+  } else if (op->op.same_as(tl::get_mbarrier())) {
    std::string barrier_name = "_mbarrier";
    std::string barrier_id = this->PrintExpr(op->args[0]);
    os << barrier_name + "[" + barrier_id + "]";
@@ -819,50 +819,50 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
    print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
    print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
-  } else if (op->op.same_as(tl::MBarrierExpectTX())) {
+  } else if (op->op.same_as(tl::mbarrier_expect_tx())) {
    print_extern_call_stmt("tl::mbarrier_expect_tx");
-  } else if (op->op.same_as(tl::MBarrierWaitParity())) {
+  } else if (op->op.same_as(tl::mbarrier_wait_parity())) {
    print_extern_call_stmt("tl::mbarrier_wait");
-  } else if (op->op.same_as(tl::SyncThreadsPartialOp())) {
+  } else if (op->op.same_as(tl::sync_thread_partial())) {
    print_extern_call_stmt("tl::syncthreads_partial");
-  } else if (op->op.same_as(tl::TMALoadOp())) {
+  } else if (op->op.same_as(tl::tma_load())) {
    print_extern_call_stmt("tl::tma_load");
-  } else if (op->op.same_as(tl::TMALoadIm2ColOp())) {
+  } else if (op->op.same_as(tl::tma_load_im2col())) {
    print_extern_call_stmt("tl::tma_load_im2col");
-  } else if (op->op.same_as(tl::TMAStoreOp())) {
+  } else if (op->op.same_as(tl::tma_store())) {
    print_extern_call_stmt("tl::tma_store");
-  } else if (op->op.same_as(tl::LDMatrixOp())) {
+  } else if (op->op.same_as(tl::ptx_ldmatirx())) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
    if (trans == 1)
      func_name += "_trans";
    print_extern_call_stmt(func_name, 2);
-  } else if (op->op.same_as(tl::STMatrixOp())) {
+  } else if (op->op.same_as(tl::ptx_stmatirx())) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
    if (trans == 1)
      func_name += "_trans";
    print_extern_call_stmt(func_name, 2);
-  } else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
+  } else if (op->op.same_as(tl::fence_proxy_async())) {
    print_extern_call_stmt("tl::fence_proxy_async");
-  } else if (op->op.same_as(tl::TMAStoreArrive())) {
+  } else if (op->op.same_as(tl::tma_store_arrive())) {
    print_extern_call_stmt("tl::tma_store_arrive");
-  } else if (op->op.same_as(tl::TMAStoreWait())) {
+  } else if (op->op.same_as(tl::tma_store_wait())) {
    print_extern_call_stmt("tl::tma_store_wait<0>");
-  } else if (op->op.same_as(tl::SetMaxNReg())) {
+  } else if (op->op.same_as(tl::set_max_nreg())) {
    this->PrintIndent();
    int nreg = Downcast<IntImm>(op->args[0])->value;
    int is_inc = Downcast<IntImm>(op->args[1])->value;
    std::string func_name =
        is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
    this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
-  } else if (op->op.same_as(tl::WaitWgmma())) {
+  } else if (op->op.same_as(tl::wait_wgmma())) {
    this->PrintIndent();
    int num_mma = Downcast<IntImm>(op->args[0])->value;
    this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
-  } else if (op->op.same_as(tl::PackB16Op())) {
+  } else if (op->op.same_as(tl::pack_b16())) {
    os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
       << this->PrintExpr(op->args[1]) << ")";
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
...
@@ -765,7 +765,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
    std::string barrier_name = "_mbarrier";
    this->stream << "__shared__ uint64_t " << barrier_name << "["
                 << barrier_count << "];\n";
-  } else if (op->op.same_as(tl::GetMBarrierOp())) {
+  } else if (op->op.same_as(tl::get_mbarrier())) {
    std::string barrier_name = "_mbarrier";
    std::string barrier_id = this->PrintExpr(op->args[0]);
    os << barrier_name + "[" + barrier_id + "]";
@@ -777,46 +777,46 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
    print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
    print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
-  } else if (op->op.same_as(tl::MBarrierExpectTX())) {
+  } else if (op->op.same_as(tl::mbarrier_expect_tx())) {
    print_extern_call_stmt("tl::mbarrier_expect_tx");
-  } else if (op->op.same_as(tl::MBarrierWaitParity())) {
+  } else if (op->op.same_as(tl::mbarrier_wait_parity())) {
    print_extern_call_stmt("tl::mbarrier_wait");
-  } else if (op->op.same_as(tl::SyncThreadsPartialOp())) {
+  } else if (op->op.same_as(tl::sync_thread_partial())) {
    print_extern_call_stmt("tl::syncthreads_partial");
-  } else if (op->op.same_as(tl::TMALoadOp())) {
+  } else if (op->op.same_as(tl::tma_load())) {
    print_extern_call_stmt("tl::tma_load");
-  } else if (op->op.same_as(tl::TMALoadIm2ColOp())) {
+  } else if (op->op.same_as(tl::tma_load_im2col())) {
    print_extern_call_stmt("tl::tma_load_im2col");
-  } else if (op->op.same_as(tl::TMAStoreOp())) {
+  } else if (op->op.same_as(tl::tma_store())) {
    print_extern_call_stmt("tl::tma_store");
-  } else if (op->op.same_as(tl::LDMatrixOp())) {
+  } else if (op->op.same_as(tl::ptx_ldmatirx())) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
    if (trans == 1)
      func_name += "_trans";
    print_extern_call_stmt(func_name, 2);
-  } else if (op->op.same_as(tl::STMatrixOp())) {
+  } else if (op->op.same_as(tl::ptx_stmatirx())) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
    if (trans == 1)
      func_name += "_trans";
    print_extern_call_stmt(func_name, 2);
-  } else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
+  } else if (op->op.same_as(tl::fence_proxy_async())) {
    print_extern_call_stmt("tl::fence_proxy_async");
-  } else if (op->op.same_as(tl::SetMaxNReg())) {
+  } else if (op->op.same_as(tl::set_max_nreg())) {
    this->PrintIndent();
    int nreg = Downcast<IntImm>(op->args[0])->value;
    int is_inc = Downcast<IntImm>(op->args[1])->value;
    std::string func_name =
        is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
    this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
-  } else if (op->op.same_as(tl::WaitWgmma())) {
+  } else if (op->op.same_as(tl::wait_wgmma())) {
    this->PrintIndent();
    int num_mma = Downcast<IntImm>(op->args[0])->value;
    this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
-  } else if (op->op.same_as(tl::PackB16Op())) {
+  } else if (op->op.same_as(tl::pack_b16())) {
    os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
       << this->PrintExpr(op->args[1]) << ")";
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
...
/*!
 * \file eliminate_storage_sync_for_mbarrier.cc
 */

#include "./storage_access.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

class Eliminator : public IRMutatorWithAnalyzer {
public:
  static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    Eliminator transformer(&analyzer);
    return transformer.VisitStmt(stmt);
  }

  Eliminator(arith::Analyzer *analyzer) : IRMutatorWithAnalyzer(analyzer) {
    in_mbarrier_region_ = false;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == "thread_extent") {
      const VarNode *var = nullptr;
      if (op->node->IsInstance<VarNode>()) {
        var = static_cast<const VarNode *>(op->node.get());
        if (var->name_hint == "threadIdx.x") {
          thread_extent_ = op;
        }
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

  Stmt VisitStmt_(const EvaluateNode *op) final {
    const CallNode *call = nullptr;
    if (op->value->IsInstance<CallNode>()) {
      call = static_cast<const CallNode *>(op->value.get());
      if (call->op.same_as(builtin::tvm_storage_sync())) {
        // Skip storage sync if we're in a region with mbarrier operations
        if (in_mbarrier_region_) {
          return Stmt();
        }
      } else if (call->op.same_as(builtin::ptx_arrive_barrier()) ||
                 call->op.same_as(builtin::ptx_wait_barrier())) {
        in_mbarrier_region_ = true;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

  Stmt VisitStmt_(const IfThenElseNode *op) final {
    bool old_in_mbarrier = in_mbarrier_region_;
    Stmt then_case = VisitStmt(op->then_case);
    Stmt ret;
    if (op->else_case.defined()) {
      in_mbarrier_region_ = old_in_mbarrier;
      Stmt else_case = VisitStmt(op->else_case.value());
      in_mbarrier_region_ = old_in_mbarrier || in_mbarrier_region_;
      ret = IfThenElse(VisitExpr(op->condition), then_case, else_case);
    } else {
      in_mbarrier_region_ = old_in_mbarrier || in_mbarrier_region_;
      ret = IfThenElse(VisitExpr(op->condition), then_case, Stmt());
    }
    return ret;
  }

private:
  bool in_mbarrier_region_;
  const AttrStmtNode *thread_extent_{nullptr};
};

using namespace tir::transform;

namespace transform {

tvm::transform::Pass EliminateStorageSyncForMBarrier() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto *n = f.CopyOnWrite();
    n->body = Eliminator::Substitute(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.EliminateStorageSyncForMBarrier",
                            {});
}

TVM_REGISTER_GLOBAL("tl.transform.EliminateStorageSyncForMBarrier")
    .set_body_typed(EliminateStorageSyncForMBarrier);

} // namespace transform
} // namespace tl
} // namespace tvm
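A minimal sketch of invoking the new pass on its own from Python, assuming it is surfaced under `tilelang.transform` (the registered global above is `tl.transform.EliminateStorageSyncForMBarrier`; the variable `mod` is a placeholder for an IRModule produced by earlier lowering stages):

import tilelang

mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)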
@@ -56,7 +56,8 @@ public:
  void VisitStmt_(const EvaluateNode *op) final {
    Proxy proxy = Proxy::kAsync;
    if (auto call = op->value.as<CallNode>()) {
-      if (call->op.same_as(LDMatrixOp()) || call->op.same_as(STMatrixOp())) {
+      if (call->op.same_as(ptx_ldmatirx()) ||
+          call->op.same_as(ptx_stmatirx())) {
        proxy = Proxy::kGeneric;
      }
    }
@@ -123,13 +124,13 @@ public:
private:
  Stmt VisitStmt_(const EvaluateNode *op) final {
    if (auto call = op->value.as<CallNode>()) {
-      if (call->op.same_as(TMAStoreOp())) {
+      if (call->op.same_as(tma_store())) {
        Array<Stmt> new_body;
        new_body.push_back(GetRef<Evaluate>(op));
        new_body.push_back(
-            Evaluate(Call(DataType::Handle(), TMAStoreArrive(), {})));
+            Evaluate(Call(DataType::Handle(), tma_store_arrive(), {})));
        new_body.push_back(
-            Evaluate(Call(DataType::Handle(), TMAStoreWait(), {})));
+            Evaluate(Call(DataType::Handle(), tma_store_wait(), {})));
        return SeqStmt(std::move(new_body));
      }
    }
@@ -157,7 +158,7 @@ private:
    Array<Stmt> new_body;
    Proxy cur_proxy, prev_proxy;
    auto fence_stmt =
-        Evaluate(Call(DataType::Handle(), FenceProxyAsyncOp(), {}));
+        Evaluate(Call(DataType::Handle(), fence_proxy_async(), {}));
    prev_proxy = get_generic_proxy(op->seq[0]);
    new_body.push_back(VisitStmt(op->seq[0]));
    if (op->seq.size() > 1) {
...
@@ -256,7 +256,6 @@ public:
    auto &next = infer_list_[cur_infer_id];
    auto iter_var = thread_var_vec_[cur_infer_id];
    auto thread_bounds = thread_bounds_vec_[cur_infer_id];
-
    // Double-check that 'next' is valid
    ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
                            << "] is null inside run_infer_step.";
@@ -420,9 +419,10 @@ private:
      auto const_int_bound = analyzer_.const_int_bound(thread_var_);
      auto min_value = const_int_bound->min_value;
      auto max_value = const_int_bound->max_value;
+      auto extent = max_value - min_value + 1;
      auto dtype = thread_var_->var.dtype();
      thread_bounds_vec_.push_back(Range::FromMinExtent(
-          IntImm(dtype, min_value), IntImm(dtype, max_value + 1)));
+          IntImm(dtype, min_value), IntImm(dtype, extent)));
    } else {
      thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
    }
@@ -458,9 +458,10 @@ private:
        analyzer_.const_int_bound.IsBound(thread_var_->var)) {
      auto const_int_bound = analyzer_.const_int_bound(thread_var_);
      auto dtype = thread_var_->var.dtype();
-      thread_bounds_vec_.push_back(Range::FromMinExtent(
-          IntImm(dtype, const_int_bound->min_value),
-          IntImm(dtype, const_int_bound->max_value + 1)));
+      auto extent =
+          const_int_bound->max_value - const_int_bound->min_value + 1;
+      thread_bounds_vec_.push_back(Range::FromMinExtent(
+          IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent)));
    } else {
      thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
    }
@@ -568,9 +569,9 @@ private:
      }
    });
-    auto loop_layout = result_.for_map[root];
    bool parallel_loop = !is_register_store && !skip_thread_partition_;

    if (parallel_loop) {
+      auto loop_layout = result_.for_map[root];
      for_node =
          PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
    }
...
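A hedged numeric illustration of the `Range::FromMinExtent` fix above (values chosen for a hypothetical second warp group, not taken from the diff):

# const_int_bound for threads 128..255 gives:
min_value, max_value = 128, 255
extent = max_value - min_value + 1  # 128: what FromMinExtent expects
old_argument = max_value + 1        # 256: what the old code passed as the extent
# With the fix, thread_bounds describes [128, 256) correctly, which is why the
# consumers above (gemm.cc, parallel.cc, cumsum.cc, loop_partition.cc) can now
# use thread_bounds->extent directly instead of subtracting ->min again.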
@@ -188,8 +188,7 @@ Fragment PlanLoopPartition(For op, size_t num_thread, int vectorize_size) {
}

Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) {
-  size_t num_thread =
-      *as_const_int(thread_range->extent) - *as_const_int(thread_range->min);
+  size_t num_thread = *as_const_int(thread_range->extent);
  LoopPartitioner partitioner;
  Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size);
  auto node = make_object<FragmentNode>(*fragment.get());
...
@@ -64,6 +64,7 @@ private:
  void VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
+    iter_map_.Set(node->loop_var, Range(node->min, node->extent));
    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }
@@ -138,7 +139,6 @@ private:
      max_vector_size = gcd_base;
    }
    vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
-
    PrimExpr elem_offset = 0;
    PrimExpr stride = 1;
    for (int i = indices.size() - 1; i >= 0; --i) {
@@ -232,6 +232,7 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
  ICHECK(target_vectorized_size >= 1);
  if (target_vectorized_size == 1)
    return true;
+  // bind thread range
  if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
                               0))
    return false;
@@ -241,10 +242,11 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
                              iter_var_size, target_vectorized_size))));
  PrimExpr expr_transformed = analyzer->Simplify(
      Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
-  PrimExpr expr_simplified = analyzer->Simplify(expr_transformed);

  Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
  PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
+  // This simplify is necessary for thread region specifiled
+  // optimizations.
+  expr_vectorized = analyzer->Simplify(expr_vectorized);

  auto ramp_node = expr_vectorized.as<RampNode>();
  if (!ramp_node) {
    // Broadcast value
...
@@ -34,6 +34,7 @@ namespace tl {
using namespace tir;

int GetVectorizeSize(const For &loop);
For VectorizeLoop(const For &loop, int vectorize_hint = -1);
+
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
...
@@ -49,9 +49,9 @@ public:
    Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
                           {StringImm("arg_value"), 16});
    Array<PrimExpr> init_desc_args;
-    if (call->op.same_as(CreateTMADescriptorOp())) {
+    if (call->op.same_as(create_tma_descriptor())) {
      init_desc_args.push_back(StringImm(tvm_tensormap_create_tiled));
-    } else if (call->op.same_as(CreateTMAIm2ColDescriptorOp())) {
+    } else if (call->op.same_as(create_tma_im2col_descriptor())) {
      init_desc_args.push_back(StringImm(tvm_tensormap_create_im2col));
    } else {
      CHECK(0) << call->op;
@@ -112,8 +112,8 @@ public:
  }

  PrimExpr VisitExpr_(const CallNode *call) final {
-    if (call->op.same_as(CreateTMADescriptorOp()) ||
-        call->op.same_as(CreateTMAIm2ColDescriptorOp())) {
+    if (call->op.same_as(create_tma_descriptor()) ||
+        call->op.same_as(create_tma_im2col_descriptor())) {
      Var var;
      auto iter = desc_map_.find(GetRef<Call>(call));
      if (iter != desc_map_.end()) {
@@ -128,24 +128,24 @@ public:
            {StringImm("tl::prefetch_tma_descriptor"), var})));
      }
      return var;
-    } else if (call->op.same_as(CreateListofMBarrierOp())) {
+    } else if (call->op.same_as(create_list_of_mbarrier())) {
      ICHECK(init_mbarrier_calls_.size() == 0);
      int num_barriers = static_cast<int>(call->args.size());
      for (int i = 0; i < num_barriers; i++) {
-        PrimExpr mbarrier = Call(DataType::Handle(), GetMBarrierOp(), {i});
+        PrimExpr mbarrier = Call(DataType::Handle(), get_mbarrier(), {i});
        init_mbarrier_calls_.push_back(Evaluate(
            Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
                 {mbarrier, call->args[i]})));
      }
      return 0;
-    } else if (call->op.same_as(SyncThreadsPartialOp())) {
+    } else if (call->op.same_as(sync_thread_partial())) {
      int barrier_id = init_mbarrier_calls_.size();
      PrimExpr mbarrier =
-          Call(DataType::Handle(), GetMBarrierOp(), {barrier_id});
+          Call(DataType::Handle(), get_mbarrier(), {barrier_id});
      init_mbarrier_calls_.push_back(Evaluate(
          Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
               {mbarrier, call->args[0]})));
-      return Call(DataType::Handle(), SyncThreadsPartialOp(), {mbarrier});
+      return Call(DataType::Handle(), sync_thread_partial(), {mbarrier});
    } else {
      return StmtExprMutator::VisitExpr_(call);
    }
...
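A hedged reading of the mbarrier branch above, applied to the `gemm_ws.py` example: the single `T.create_list_of_mbarrier(128, 128)` call is replaced by per-barrier init calls, so the generated kernel initializes `_mbarrier[0]` and `_mbarrier[1]` each with an expected arrival count of 128 (one warp group):

# T.create_list_of_mbarrier(128, 128) becomes, roughly:
#   ptx_init_barrier_thread_count(_mbarrier[0], 128)
#   ptx_init_barrier_thread_count(_mbarrier[1], 128)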