Unverified Commit 5c11d245 authored by Lei Wang, committed by GitHub

[Refactor] Merge bulk copy into copy and improve layout inference for bulk copy (#746)

* [Refactor] Merge bulk copy into copy and refactor layout inference for bulk copy

* Deleted the `bulk_copy` operator implementation and its header file, as they are no longer needed.
* Introduced a new function `cuTensorMapType()` to return the data type used for CUDA tensor mapping (a reference snippet follows below).
* Updated related files to reflect these changes, ensuring that the codebase remains clean and maintainable.
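
For reference, `DataType::UInt(8, 128)` is a 128-lane `uint8` vector, i.e. an opaque 128-byte value matching the size of a `CUtensorMap` descriptor. A small sketch of the equivalent spelling on the TVM Python side (the `tvm.DataType` usage here is illustrative and not part of this diff):

```python
import tvm

# cuTensorMapType() in the C++ changes returns DataType::UInt(8, 128);
# in TVM's Python API that dtype is spelled "uint8x128".
cu_tensor_map_dtype = tvm.DataType("uint8x128")
assert cu_tensor_map_dtype.bits == 8 and cu_tensor_map_dtype.lanes == 128
```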

* lint fix

* Fix typos in intrinsic names and remove an unused print statement in block_sparse_attn_tilelang.py. Update references from `ptx_ldmatirx` to `ptx_ldmatrix` across multiple files for consistency.

* remove bulk copy

* Refactor copy and atomic add operations to support TMA lower configuration

- Updated `GetCopyInst` to accept a `disable_tma_lower` parameter, allowing for conditional usage of TMA in bulk load/store operations.
- Modified `Lower` method in `Copy` to incorporate the new TMA configuration.
- Refactored `AtomicAdd::Lower` to streamline layout inference and vectorization logic.
- Removed unused `disable_tma_lower` field from `LowerArgs` structure for clarity.
- Enhanced atomic add vectorization by replacing the buggy implementation with a more robust loop vectorization approach.

* Enhance TMA bulk copy logic in `LowerBulkCopy` method

- Added a condition to set `desc.swizzle` to `CU_TENSOR_MAP_SWIZZLE_NONE` when `shared_layout` matches `linear_layout`, improving clarity in layout handling.
- Updated warning log to provide more detailed information about fallback scenarios, including source and destination buffer names and shapes, enhancing debugging capabilities.

* lint fix

* Remove fallback logging for non-swizzled global layout in `LowerBulkCopy` method to streamline the bulk copy logic. This change enhances code clarity by eliminating unnecessary warning messages related to inner box dimensions.

* Enhance reshape kernel compilation in `run_reshape` and `run_reshape_smem_1d_2_2d` functions

- Updated the `tl.compile` method to include `pass_configs` that disable TMA lowering and warp specialization, addressing shared memory layout transformation limitations (a sketch of the resulting setup follows below).
- Added TODO comments to indicate the need for further improvements in shared memory handling.
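
A minimal sketch of the resulting setup, assuming `tilelang.compile` accepts `out_idx` and `pass_configs` as described above and that `T.Tensor`, `T.Kernel`, and `T.alloc_shared` behave as in TileLang's public examples; the kernel body is a simplified stand-in for the actual reshape kernels, not code from this diff:

```python
import tilelang
import tilelang.language as T

M, N, block_M = 1024, 64, 128


@T.prim_func
def tiled_copy(A: T.Tensor((M, N), "float16"), B: T.Tensor((M, N), "float16")):
    with T.Kernel(T.ceildiv(M, block_M), threads=128) as bx:
        A_shared = T.alloc_shared((block_M, N), "float16")
        T.copy(A[bx * block_M, 0], A_shared)
        T.copy(A_shared, B[bx * block_M, 0])


# The two pass configs added to run_reshape/run_reshape_smem_1d_2_2d: disable
# TMA lowering and warp specialization until the shared memory layout
# transformation supports these kernels.
kernel = tilelang.compile(
    tiled_copy,
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
```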

* Update `native_sparse_attention` function to include TMA configuration options

- Added `pass_configs` to the JIT decorator to disable TMA lower and warp specialization, addressing potential issues with shared memory layout transformations.
- Updated comments to clarify modifications in tensor shapes for inference, specifically setting `q` sequence length to 1.

* Refactor JIT decorator formatting in `native_sparse_attention` function

- Improved readability by reformatting the JIT decorator parameters for `native_sparse_attention`, ensuring consistent style across the codebase.
- No functional changes were made; this update focuses on code clarity and maintainability.

* Enhance thread management and logging in TileLang compilation

- Added a method to check if printing is enabled during compilation, improving control over logging behavior.
- Updated the JIT kernel class to utilize the new method for logging compilation status, ensuring consistent and clear output.
- Added comments to clarify the purpose of changes and improve code readability.

* Add warp specialization scope and refactor register management in TileLang

- Introduced a new constant `kWarpSpecializationScope` in `builtin.h` for better attribute management.
- Removed the `SetMaxNRegCollector` class and its related logic from `warp_specialized_rewriter.cc`, streamlining the warp specialization process.
- Added functions `annotate_producer_reg_dealloc` and `annotate_consumer_reg_alloc` in `builtin.py` to facilitate register management.
- Implemented `AnnotateWarpGroupRegAlloc` in `__init__.py` to inject register allocation calls into warp-specialized functions, enhancing the overall register handling in the compilation process (a usage sketch follows below).
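
The user-facing counterpart of this rework, visible in the hunks further down, is that kernels now call `T.disable_warp_group_reg_alloc()` where they previously called `T.no_set_max_nreg()`, with `AnnotateWarpGroupRegAlloc` injecting the actual `set_max_nreg` calls afterwards. A minimal usage sketch; the kernel shape, block sizes, and `T.Tensor`/`T.Kernel`/`T.alloc_shared` helpers are illustrative assumptions:

```python
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def pipelined_copy(M, N, block_M, block_N, dtype="float16"):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, block_M), threads=256) as bx:
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            # Opt this kernel out of warp-group register reallocation;
            # previously spelled T.no_set_max_nreg().
            T.disable_warp_group_reg_alloc()
            for k in T.Pipelined(T.ceildiv(N, block_N), num_stages=2):
                T.copy(A[bx * block_M, k * block_N], A_shared)
                T.copy(A_shared, B[bx * block_M, k * block_N])

    return main
```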

* Refactor test for InjectSetMaxNReg pass in TileLang

- Improved readability by restructuring conditional checks and assertions in the test cases.
- Enhanced clarity in the collection of `set_max_nreg` calls by simplifying the logic (see the sketch after this list).
- Ensured consistent formatting and spacing throughout the test functions for better maintainability.
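
A minimal sketch of how such a collection can be written against a lowered PrimFunc; the `tl.set_max_nreg` op name and the helper below are illustrative assumptions rather than code from this diff:

```python
import tvm
from tvm import tir


def collect_set_max_nreg_calls(prim_func: tir.PrimFunc):
    """Gather every tl.set_max_nreg call in the lowered function body."""
    calls = []

    def visit(node):
        if (isinstance(node, tir.Call) and isinstance(node.op, tvm.ir.Op)
                and node.op.name == "tl.set_max_nreg"):
            calls.append(node)

    tir.stmt_functor.post_order_visit(prim_func.body, visit)
    return calls
```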

* Enhance bulk copy and store checks in `Copy` class

- Updated scope validation for source and destination tensors in `CheckBulkLoad` and `CheckBulkStore` methods to include both `shared.dyn` and `shared` as valid options (a usage sketch follows this list).
- Modified `CheckLDSMCopy` and `CheckSTSMCopy` methods to accommodate the new scope validation, ensuring compatibility with shared memory configurations.
- Improved logging in `LowerBulkCopy` to provide clearer warnings regarding unsupported swizzle layouts, including source and destination names for better debugging.
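
With this relaxation, a copy whose shared-side buffer lives in static `shared` scope (rather than the default `shared.dyn`) remains eligible for the TMA bulk path. A minimal sketch, assuming `T.alloc_shared` accepts a `scope` keyword as in TileLang's allocation helpers (not shown in this diff):

```python
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def static_smem_copy(M, N, block_M, dtype="float16"):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, block_M), threads=128) as bx:
            # Static shared memory: CheckBulkLoad/CheckBulkStore now accept
            # both "shared" and "shared.dyn" as the shared-side scope.
            A_shared = T.alloc_shared((block_M, N), dtype, scope="shared")
            T.copy(A[bx * block_M, 0], A_shared)
            T.copy(A_shared, B[bx * block_M, 0])

    return main
```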

* lint fix
parent cb37bfef
......@@ -192,7 +192,7 @@ def matmul_sp(M, N, K):
# Clear out the accumulation buffer
T.clear(C_local)
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
......
......@@ -52,7 +52,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], qKV_shared)
......
......@@ -8,7 +8,14 @@ import tilelang.testing
tilelang.testing.set_random_seed(42)
@tilelang.jit(out_idx=[-1])
# TODO(lei): workaround, as threads is not divisible by warp group size,
# auto warp specialization may have some bugs.
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def native_sparse_attention(
batch,
heads,
......@@ -22,7 +29,7 @@ def native_sparse_attention(
if scale is None:
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
# Modified shapes for inference (q has seq_len=1)
# Modified shapes for inference (q has seq_len=1)a
q_shape = [batch, 1, heads, dim] # Changed seq_len to 1
kv_shape = [batch, seq_len, head_kv, dim]
block_indices_shape = [batch, 1, head_kv, selected_blocks] # Changed seq_len to 1
......@@ -167,8 +174,6 @@ def main():
block_counts=block_counts,
block_size=block_size,
)
print("out", out)
print("ref", ref)
torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
......
......@@ -338,7 +338,7 @@ def matmul(M,
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
if threads == 512:
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()
T.clear(C_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
......
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse
......@@ -340,11 +339,10 @@ def main(BATCH: int = 1,
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
def run():
O_ref.backward(dO, retain_graph=True)
......
......@@ -122,7 +122,7 @@ def tilelang_chunk_fwd_o(
T.clear(A_fragment)
T.clear(O_fragment)
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
......
......@@ -101,7 +101,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
})
T.fill(A_fragment, 0)
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
......
......@@ -107,7 +107,7 @@ def tilelang_recompute_w_u_fwd(
U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared),
})
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
......
......@@ -178,7 +178,6 @@ def test_topk_sparse_attention():
# Run tilelang kernel
kernel = blocksparse_flashattn(
BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
print(kernel.get_kernel_source())
tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
# Compute reference
......
......@@ -182,27 +182,25 @@ For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
For vectorized_thread_loop;
auto par_op = std::make_unique<ParallelOp>(fused_loop);
if (!is_cpu_target) {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
}
auto loop_layout = par_op->GetLoopLayout();
Var thread_var = T.thread_var;
Range thread_bounds = T.thread_bounds;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
vectorized_thread_loop = VectorizeAtomicAdd(
thread_loop, thread_var, thread_bounds, GetArchInt(target));
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
}
auto loop_layout = par_op->GetLoopLayout();
Var thread_var = T.thread_var;
Range thread_bounds = T.thread_bounds;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
// TODO(@dyq): buggy implementation, need to fix
// vectorized_thread_loop = VectorizeAtomicAdd(
// thread_loop, thread_var, thread_bounds, GetArchInt(target));
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
......
......@@ -29,6 +29,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
DataType cuTensorMapType() { return DataType::UInt(8, 128); }
#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op &OpName() { \
static const Op &op = Op::Get("tl." #OpName); \
......@@ -78,7 +80,7 @@ TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_ldmatirx)
TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
......
......@@ -15,6 +15,8 @@ namespace tl {
namespace attr {
static constexpr const char *kPaddingMap = "padding_map";
static constexpr const char *kWarpSpecializationScope =
"kWarpSpecializationScope";
} // namespace attr
static constexpr const char *kDebugMergeSharedMemoryAllocations =
......@@ -54,6 +56,14 @@ static constexpr const char *kDisableDynamicTailSplit =
*/
static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
/*!
* \brief Get the type of the CUDA tensor map
*
* DataType cuTensorMapType()
*
*/
DataType cuTensorMapType();
/*!
* \brief tvm intrinsics for TMADescriptor creation for tiled load
*
......@@ -138,15 +148,15 @@ TVM_DLL const Op &mbarrier_expect_tx();
/*!
* \brief tvm intrinsics for ldmatrix
*
* ptx_ldmatirx(transposed, num, shared_addr, local_addr)
* ptx_ldmatrix(transposed, num, shared_addr, local_addr)
*
*/
TVM_DLL const Op &ptx_ldmatirx();
TVM_DLL const Op &ptx_ldmatrix();
/*!
* \brief tvm intrinsics for stmatrix
*
* ptx_ldmatirx(transposed, num, shared_addr, int32_values...)
* ptx_ldmatrix(transposed, num, shared_addr, int32_values...)
*
*/
TVM_DLL const Op &ptx_stmatrix();
......
/*!
* \file tl/op/bulk_copy.h
* \brief Bulk copy operator.
*
*/
#ifndef TVM_TL_OP_BULK_COPY_H_
#define TVM_TL_OP_BULK_COPY_H_
#include "elem.h"
namespace tvm {
namespace tl {
using namespace tir;
struct TMADesc {
size_t rank;
int data_type;
Array<PrimExpr> global_shape, global_stride;
Array<PrimExpr> smem_box, smem_stride;
PrimExpr global_addr;
int swizzle;
int interleave;
int oob_fill;
int l2_promotion;
Array<PrimExpr> EncodeCallArgs() const;
};
DataType cuTensorMapType();
struct TMAIm2ColDesc {
size_t rank;
int data_type;
Array<PrimExpr> global_shape, global_stride, elem_stride; // rank
Array<PrimExpr> lower_corner, upper_corner; // rank - 2
PrimExpr global_addr;
int smem_box_pixel, smem_box_channel;
int swizzle;
int interleave;
int oob_fill;
int l2_promotion;
Array<PrimExpr> EncodeCallArgs() const;
};
class Conv2DIm2ColOp : public Operator {
public:
Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Conv2DIm2ColOp>(*this);
}
private:
Buffer src, dst;
int stride, padding, dilation, kernel, eviction_policy;
PrimExpr nhw_step, c_step;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_BULK_COPY_H_
\ No newline at end of file
/*!
* \file tl/op/bulk_copy.cc
* \brief Bulk copy operator.
* \file tl/op/copy.cc
* \brief Define copy operator for various memory transfer strategies (Normal,
* Bulk/TMA, LDSM/STSM) and lowering logic for GPU code generation.
*
* This module is part of TVM TensorIR's Tensor Layout (TL) operations,
* implementing memory copy operations that can target CPUs or GPUs with
* optimization for different instructions like bulk copy, matrix load/store,
* and Hopper's new TMA (Tensor Memory Accelerator).
*/
#include "bulk_copy.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "copy.h"
#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "../target/cuda.h"
#include "../target/utils.h"
#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tl {
using namespace tir;
/*!
* \brief Helper to map TVM's DataType to CUDA's CUtensorMapDataType enum value.
* This function converts TVM data types to CUDA tensor map data types for TMA
* operations.
*/
static int to_CUtensorMapDataType(DataType dtype) {
CUtensorMapDataType tp;
if (dtype.is_float()) {
......@@ -82,40 +97,669 @@ static int to_CUtensorMapDataType(DataType dtype) {
return static_cast<int>(tp);
}
/*!
* \brief Utility function to reverse an array.
* This is commonly used to convert between row-major and column-major layouts.
*/
template <typename T> static Array<T> ReverseArray(Array<T> array) {
return Array<T>{array.rbegin(), array.rend()};
}
Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (T.disable_tma_lower)
return Stmt();
if (!TargetIsHopper(T.target))
return Stmt();
bool is_load;
if (src.scope() == "global" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared")) {
is_load = true;
} else if (dst.scope() == "global" &&
(src.scope() == "shared.dyn" || src.scope() == "shared")) {
is_load = false;
/*!
* \brief Constructor for Copy operator.
* \param args Array of PrimExpr representing the arguments of the copy
* operation. \param vmap BufferMap mapping original buffer names to new buffer
* names.
*/
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto expr = args[i];
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region.GetRanges();
bf[i] = region.GetBuffer();
}
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
auto coalesced_width = Downcast<IntImm>(args[2]);
if (coalesced_width->value > 0) {
this->coalesced_width = coalesced_width;
}
}
if (args.size() >= 4) {
this->disable_tma = Downcast<Bool>(args[3]);
}
if (args.size() >= 5) {
this->eviction_policy = args[4].as<IntImmNode>()->value;
}
}
/*!
* \brief Create iterator variables for the copy operation.
* This function creates iteration variables for dimensions that have extent
* > 1. \return Array of IterVar representing the iterator variables for the
* copy operation.
*/
Array<IterVar> Copy::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
if (is_one(src_range[i]->extent))
continue;
Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
idx++;
loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
}
return loop_vars;
}
/*!
* \brief Create indices for the copy operation.
* This function generates the actual index expressions for accessing source or
* destination buffers. For dimensions with extent=1, it uses the range minimum;
* for others, it adds the iteration variable. \param ivs Array of IterVar
* returned by MakeIterVars(). \param src_dst 0 for src_indices, 1 for
* dst_indices. \return Array of PrimExpr representing the indices for the copy
* operation.
*/
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent))
indices.push_back(ranges[i]->min);
else {
indices.push_back(ranges[i]->min + ivs[idx]->var);
idx++;
}
}
ICHECK(idx == ivs.size())
<< "idx = " << idx << ", ivs.size() = " << ivs.size()
<< "src name = " << src->name << ", dst name = " << dst->name;
return indices;
}
/*!
* \brief Create predicate for the copy operation.
* This function generates boundary checks to ensure memory access safety.
* It creates conditions like (min + iv) < extent and (min + iv) >= 0 for each
* dimension. \param analyzer Arithmetic analyzer for simplification. \param ivs
* Array of IterVar. \param extents Array of PrimExpr representing the extents
* of the copy operation. \param src_dst 0 for src_indices, 1 for dst_indices.
* \return PrimExpr representing the predicate for the copy operation.
*/
PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs, Array<PrimExpr> extents,
int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent))
continue;
PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
cond = ranges[i]->min + ivs[idx]->var >= 0;
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
idx++;
}
if (cond_list.empty())
return {};
else {
PrimExpr cond = cond_list[0];
for (size_t i = 1; i < cond_list.size(); i++)
cond = And(cond, cond_list[i]);
return cond;
}
}
/*!
* \brief Create SIMT loop for the copy operation.
* This function generates a single-threaded loop structure for the copy
* operation. It handles scalar copies (single element) and multi-dimensional
* copies with nested loops. \param analyzer Arithmetic analyzer for
* simplification. \return For representing the SIMT loop for the copy
* operation.
*/
For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0;
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial,
BufferStore(dst, BufferLoad(src, {0}), {0}));
}
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
ICHECK(loop_vars.size() <= dst_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
PrimExpr value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype)
value = Cast(dst->dtype, value);
if (src_predicate.defined())
value = if_then_else(src_predicate, value, make_zero(dst->dtype));
Stmt body = BufferStore(dst, value, dst_indices);
if (dst_predicate.defined())
body = IfThenElse(dst_predicate, body);
for (int i = loop_vars.size() - 1; i >= 0; i--) {
Map<String, ObjectRef> annotations = {};
if (coalesced_width.defined()) {
annotations.Set("coalesced_width", coalesced_width);
}
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body, std::nullopt, annotations);
}
return Downcast<For>(body);
}
/*!
* \brief Compute linear layout for TMA copy.
* This function creates a linear layout transformation for shared memory in TMA
* operations. It transforms multi-dimensional indices into a linear address
* using a 256-element block pattern. The transformation follows: [i, j] ->
* [i//256, j//256, i%256, j%256] \param shared_tensor Buffer representing the
* shared tensor. \return Layout representing the linear layout for the TMA
* copy.
*/
Layout Copy::ComputeLinearLayout(const Buffer &shared_tensor) const {
Array<PrimExpr> input_size = shared_tensor->shape;
Array<PrimExpr> forward_vars;
for (size_t i = 0; i < input_size.size(); i++) {
forward_vars.push_back(InputPlaceholder(i));
}
// [i, j] -> [i // 256, j // 256, i % 256, j % 256]
Array<PrimExpr> forward_index;
for (size_t i = 0; i < input_size.size(); i++) {
forward_index.push_back(FloorDiv(forward_vars[i], 256));
}
for (size_t i = 0; i < input_size.size(); i++) {
forward_index.push_back(FloorMod(forward_vars[i], 256));
}
return Layout(input_size, forward_index);
}
/*!
* \brief Infer layout for the copy operation.
* This function determines the optimal memory layout for the copy operation
* based on the target architecture. For bulk load/store operations, it may
* apply swizzling layouts for better performance. For LDSM/STSM operations, it
* uses register layout inference from the underlying parallel op. \param T
* LayoutInferArgs containing target and layout map. \param level InferLevel
* indicating the level of layout inference. \return LayoutMap containing the
* inferred layout.
*/
LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
auto target = T.target;
using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma);
if (copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) {
// if can apply swizzling, we skip layout inference
// for bulk load/store, we can directly apply the layout of normal copy
// This must be a global/shared layout, so we can skip the parallel op
// layout inference (parallel layout inference only annotate the loop layout
// and the register layout).
bool is_load = copy_inst == CopyInst::kBulkLoad;
Buffer global_tensor = is_load ? src : dst;
Buffer shared_tensor = is_load ? dst : src;
// check shared layout is non-swizzle
// skip layout inference if shared layout is already annotated
if (level == InferLevel::kFree && !T.layout_map.count(shared_tensor)) {
// create a new layout map for tma linear layout
Layout linear_layout = ComputeLinearLayout(shared_tensor);
return Map<Buffer, Layout>({{shared_tensor, linear_layout}});
}
}
// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
// Use parallel op to infer the layout
if (!par_op_) {
arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
}
return par_op_->InferLayout(T, level);
}
/*!
* \brief Check if the copy operation is a bulk load.
* This function verifies if the copy operation can be implemented using CUDA's
* Bulk Load instruction. Requirements include: target supports bulk copy,
* source is global memory, destination is shared.dyn, and both buffers have the
* same data type. \param target Target device. \return True if the copy
* operation is a bulk load, false otherwise.
*/
bool Copy::CheckBulkLoad(Target target) const {
// 1. arch must have bulk copy support
if (!TargetHasBulkCopy(target))
return false;
// 2. src and dst must be global and shared
if (src.scope() != "global" ||
(dst.scope() != "shared.dyn" && dst.scope() != "shared"))
return false;
// 3. check shape.
// TODO(lei): validate if we can utilize tma under this shape.
// 4. src and dst must have the same dtype
if (src->dtype != dst->dtype) {
LOG(WARNING) << "src and dst must have the same dtype for tma load "
<< src->name << " vs. " << dst->name << " dtype " << src->dtype
<< " vs. " << dst->dtype << " will be fallback to normal copy";
return false;
}
return true;
}
/*!
* \brief Check if the copy operation is a bulk store.
* This function verifies if the copy operation can be implemented using CUDA's
* Bulk Store instruction. Requirements include: target supports bulk copy,
* source is shared.dyn, destination is global memory, and both buffers have the
* same data type. \param target Target device. \return True if the copy
* operation is a bulk store, false otherwise.
*/
bool Copy::CheckBulkStore(Target target) const {
// 1. arch must have bulk copy support
if (!TargetHasBulkCopy(target))
return false;
  // 2. src must be in shared memory ("shared.dyn" or "shared") and dst must be
  // global
if ((src.scope() != "shared.dyn" && src.scope() != "shared") ||
dst.scope() != "global")
return false;
// 3. check shape.
// TODO(lei): validate if we can utilize tma under this shape.
// 4. src and dst must have the same dtype
if (src->dtype != dst->dtype) {
LOG(WARNING) << "src and dst must have the same dtype for tma store "
<< src->name << " vs. " << dst->name << " dtype " << src->dtype
<< " vs. " << dst->dtype << " will be fallback to normal copy";
return false;
}
return true;
}
/*!
* \brief Check if the copy operation is a LDSM copy.
* This function verifies if the copy operation can be implemented using CUDA's
* Load Matrix (LDSM) instruction. Requirements include: target supports
* LDMATRIX, source is shared.dyn, destination is local.fragment. \param target
* Target device. \return True if the copy operation is a LDSM copy, false
* otherwise.
*/
bool Copy::CheckLDSMCopy(Target target) const {
return TargetHasLdmatrix(target) &&
(src.scope() == "shared.dyn" || src.scope() == "shared") &&
dst.scope() == "local.fragment";
}
/*!
* \brief Check if the copy operation is a STSM copy.
* This function verifies if the copy operation can be implemented using CUDA's
* Store Matrix (STSM) instruction. Requirements include: target supports
* STMATRIX, source is local.fragment, destination is shared.dyn. \param target
* Target device. \return True if the copy operation is a STSM copy, false
* otherwise.
*/
bool Copy::CheckSTSMCopy(Target target) const {
return TargetHasStmatrix(target) && src.scope() == "local.fragment" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared");
}
/*!
* \brief Get the copy instruction type.
* This function determines the most appropriate copy instruction based on the
* target architecture and buffer memory scopes. It checks for specialized
* instructions (TMA, LDSM, STSM) in order of preference, falling back to normal
* copy if no specialized instruction is applicable. \param target Target
* device. \return CopyInst representing the copy instruction type.
*/
Copy::CopyInst Copy::GetCopyInst(Target target, bool disable_tma_lower) const {
// disable_tma_lower is from pass_configs
// when tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER is True,
// we will not use tma for bulk load/store
if (!disable_tma_lower && CheckBulkLoad(target)) {
return CopyInst::kBulkLoad;
} else if (!disable_tma_lower && CheckBulkStore(target)) {
return CopyInst::kBulkStore;
} else if (CheckLDSMCopy(target)) {
return CopyInst::kLDSM;
} else if (CheckSTSMCopy(target)) {
return CopyInst::kSTSM;
} else {
return CopyInst::kNormal;
}
}
/*!
* \brief Lower the copy operation to PTX code.
* This function converts the high-level copy operation into low-level PTX
* instructions. It dispatches to specialized lowering functions based on the
* determined copy instruction type:
* - Bulk Load/Store: Uses Tensor Memory Accelerator (TMA) instructions
* - LDSM/STSM: Uses matrix load/store instructions for tensor cores
* - Normal: Uses standard load/store operations with loop transformations
* \param T LowerArgs containing target and layout map.
* \param analyzer Arithmetic analyzer for simplification.
* \return Stmt representing the PTX code for the copy operation.
*/
Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma);
if (copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) {
auto bulk_copy = LowerBulkCopy(T, analyzer, copy_inst);
ICHECK(bulk_copy.defined()) << "Failed to lower bulk copy";
return bulk_copy;
} else if (copy_inst == CopyInst::kLDSM || copy_inst == CopyInst::kSTSM) {
auto ldsm_copy = LowerLDSMCopy(T, analyzer, copy_inst);
ICHECK(ldsm_copy.defined()) << "Failed to lower ptx matrix copy";
return ldsm_copy;
} else if (copy_inst == CopyInst::kNormal) {
return LowerNormalCopy(T, analyzer);
} else {
LOG(FATAL) << "Unsupported copy inst " << static_cast<int>(copy_inst);
}
}
/*!
* \brief Lower the copy operation to a normal copy.
* This function generates standard load/store operations for targets that don't
* support specialized copy instructions. It applies loop fusion,
* parallelization, and vectorization transformations to optimize performance on
* both CPU and GPU targets. \param T LowerArgs containing target and layout
* map. \param analyzer Arithmetic analyzer for simplification. \return Stmt
* representing the normal copy code.
*/
Stmt Copy::LowerNormalCopy(const LowerArgs &T,
arith::Analyzer *analyzer) const {
bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU;
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto transformed_loop =
Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
For vectorized_thread_loop;
auto par_op = std::make_unique<ParallelOp>(transformed_loop);
if (is_cpu_target) {
vectorized_thread_loop = VectorizeLoop(transformed_loop);
} else {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
}
auto loop_layout = par_op->GetLoopLayout();
auto thread_var = T.thread_var;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
vectorized_thread_loop = VectorizeLoop(thread_loop);
}
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
}
return vectorized_thread_loop;
}
/*!
* \brief Lower the copy operation to LDSM/STSM copy.
* This function generates PTX code for matrix load/store operations
* (LDSM/STSM). It handles 8x8 fragment layout validation, shared memory stride
* checking, and generates optimized matrix transfer instructions for tensor
* cores. Falls back to normal copy if layout constraints are not satisfied.
* \param T LowerArgs containing target and layout map.
* \param analyzer Arithmetic analyzer for simplification.
* \param copy_inst CopyInst representing the copy instruction type.
* \return Stmt representing the LDSM/STSM copy code.
*/
Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const {
ICHECK(copy_inst == CopyInst::kLDSM || copy_inst == CopyInst::kSTSM)
<< "Invalid copy inst " << static_cast<int>(copy_inst);
bool is_ldmatrix = copy_inst == CopyInst::kLDSM;
// Check no predicates
Array<IterVar> loop_vars = MakeIterVars();
if (loop_vars.size() < 2) {
// cannot support 1-d case
return LowerNormalCopy(T, analyzer);
}
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
if (src_predicate.defined() || dst_predicate.defined()) {
// stmatrix and ldmatrix can only support no predicate
return LowerNormalCopy(T, analyzer);
}
Buffer shared_tensor = is_ldmatrix ? src : dst;
Buffer local_tensor = is_ldmatrix ? dst : src;
Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
Array<PrimExpr> local_indices_transformed =
local_layout->Forward(local_indices);
local_tensor = T.buffer_remap[local_tensor];
// currently only support 1-d case
if (local_layout->OutputDim() != 1) {
// TMA ldmatrix/stmatrix cannot support non-1-d layout, will be fallback to
// normal copy
return LowerNormalCopy(T, analyzer);
}
Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
Array<PrimExpr> shared_indices_transformed = shared_indices;
Layout shared_layout;
if (T.buffer_remap.count(shared_tensor)) {
shared_layout = T.layout_map[shared_tensor];
shared_tensor = T.buffer_remap[shared_tensor];
shared_indices_transformed = shared_layout->Forward(shared_indices);
}
// Check local_layout follows 8x8 layout
// LDSM/STSM instructions require 8x8 matrix fragment layout
// This matches the warp-level matrix multiplication pattern used in tensor
// cores We check both normal and transposed layouts to support different
// access patterns
bool is_transposed;
IterVar col_var = loop_vars[loop_vars.size() - 1];
IterVar row_var = loop_vars[loop_vars.size() - 2];
PrimExpr local_layout_thread_map =
FloorMod(local_layout->ForwardThread(local_indices, std::nullopt), 32);
PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
PrimExpr matrix_8x8_thread_map_trans =
makeGemmFragment8x8Transposed()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
PrimExpr local_indices_flattened =
local_tensor.OffsetOf(local_indices_transformed).back();
if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, col_var->var,
col_var->dom->extent, 2, analyzer)) {
is_transposed = false;
} else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans,
local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, row_var->var,
row_var->dom->extent, 2, analyzer)) {
is_transposed = true;
} else {
return Stmt();
// TMA ldmatrix/stmatrix cannot support non-8x8 layout, will be fallback to
// normal copy
return LowerNormalCopy(T, analyzer);
}
// Check shared_layout is 16 bytes continuous
// LDSM/STSM instructions require 16-byte aligned data (half-precision floats)
// This is a hardware constraint for matrix load/store operations
if (shared_tensor->dtype.bytes() != 2) {
// TMA ldmatrix/stmatrix cannot support non-16 bytes continuous layout, will
// be fallback to normal copy
return LowerNormalCopy(T, analyzer);
}
PrimExpr flattened_indice =
shared_tensor.OffsetOf(shared_indices_transformed).back();
if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
loop_vars.back()->dom->extent, 8, analyzer)) {
// TMA ldmatrix/stmatrix cannot support non-16 bytes continuous layout, will
// be fallback to normal copy
return LowerNormalCopy(T, analyzer);
}
// Can only support local_range to be a full range
for (size_t i = 0; i < dst_range.size(); i++) {
if (!is_zero(dst_range[i]->min) ||
!analyzer->CanProveEqual(dst_range[i]->extent, dst->shape[i]))
// TMA ldmatrix/stmatrix cannot support non-full range, will be fallback
// to normal copy
return LowerNormalCopy(T, analyzer);
}
// Do the lowering here, try vectorized ldmatrix/stmatrix by 4/2/1
PrimExpr extent = local_tensor->shape[0];
int num = 1;
if (analyzer->CanProveEqual(FloorMod(extent, 8), 0))
num = 4;
else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0))
num = 2;
Array<PrimExpr> args;
const Op &op = is_ldmatrix ? tl::ptx_ldmatrix() : tl::ptx_stmatrix();
args.push_back(static_cast<int>(is_transposed));
args.push_back(num);
// Create shared address with regard to local address
// if not transpose
// coords = Inverse(base + 2 * (thread / 8) % num, warp + (thread % 8) * 4))
// if transpose
// coords = Inverse(base + 2 * (thread / 8) % num + thread % 2, warp + thread
// % 8 / 2)
Var local_iter("i");
Layout inv = local_layout->Inverse();
Array<PrimExpr> shared_coords;
PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
if (!is_transposed)
shared_coords = inv->Forward(
{local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
warp + FloorMod(T.thread_var, 8) * 4});
else
shared_coords = inv->Forward(
{local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
FloorMod(T.thread_var, 2),
warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
shared_coords.pop_back(); // remove rep
if (shared_layout.defined())
shared_coords = shared_layout->Forward(shared_coords);
PrimExpr shared_addr = shared_tensor.access_ptr(
is_ldmatrix ? 1 : 2, DataType::Handle(), 1,
shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
args.push_back(shared_addr);
if (is_ldmatrix) {
    // Can only support the same dtype for ldmatrix
if (local_tensor->dtype != shared_tensor->dtype) {
// TMA ldmatrix cannot support different dtype, will be fallback to normal
// copy
return LowerNormalCopy(T, analyzer);
}
PrimExpr local_addr = local_tensor.access_ptr(
2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num));
args.push_back(local_addr);
} else {
for (int i = 0; i < num; i++) {
PrimExpr value0 =
BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i});
PrimExpr value1 =
BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1});
if (local_tensor->dtype != shared_tensor->dtype) {
value0 = Cast(shared_tensor->dtype, value0);
value1 = Cast(shared_tensor->dtype, value1);
}
PrimExpr value_packed =
Call(DataType::Int(32), pack_b16(), {value0, value1});
args.push_back(value_packed);
}
}
auto body = Evaluate(Call(DataType::Handle(), op, args));
For for_node =
For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
for_node = LoopPragmaUnroll(for_node);
auto range = T.thread_bounds;
if (range.defined()) {
auto thread_var = T.thread_var;
auto thread_var_with_offset = thread_var - range->min;
for_node.CopyOnWrite()->body =
Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
}
return for_node;
}
/*!
* \brief Lower the copy operation to bulk copy using TMA.
* This function generates PTX code for Tensor Memory Accelerator (TMA) bulk
* copy operations. It creates TMA descriptors, handles shared memory layout
* detection (including swizzling), and generates optimized bulk load/store
* instructions for Hopper architecture. Falls back to normal copy if layout or
* shape constraints are not satisfied. \param T LowerArgs containing target and
* layout map. \param analyzer Arithmetic analyzer for simplification. \param
* copy_inst CopyInst representing the copy instruction type. \return Stmt
* representing the bulk copy code.
*/
Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const {
ICHECK(copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore)
<< "Invalid copy inst " << static_cast<int>(copy_inst);
bool is_load = copy_inst == CopyInst::kBulkLoad;
Buffer global_tensor = is_load ? src : dst;
Buffer shared_tensor = is_load ? dst : src;
Array<Range> global_range = is_load ? src_range : dst_range;
Array<Range> shared_range = is_load ? dst_range : src_range;
// TMA bulk copy cannot support a non-swizzled global layout, will be fallback
// to normal copy
if (T.layout_map.count(global_tensor)) {
LOG(WARNING) << "TMA bulk copy cannot support a non-swizzled global "
"layout, fallback to normal copy.";
return Stmt();
return LowerNormalCopy(T, analyzer);
}
if (T.layout_map.count(global_tensor)) {
LOG(WARNING) << "TMA bulk copy cannot support a non-swizzled global "
"layout, fallback to normal copy.";
return Stmt();
}
// linear layout must be computed before remapping
auto linear_layout = ComputeLinearLayout(shared_tensor);
Array<PrimExpr> indices;
for (auto r : shared_range)
......@@ -184,7 +828,7 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (stride->value % 16 != 0 || stride->value >= (1ULL << 40)) {
LOG(WARNING) << "TMA bulk copy cannot support a global stride of "
<< desc.global_stride[i] << ", fallback to normal copy.";
return Stmt();
return LowerNormalCopy(T, analyzer);
}
}
}
......@@ -226,22 +870,25 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
// Detect smem layout
// Shared memory swizzling is crucial for TMA performance
// It determines how data is arranged in shared memory banks to minimize bank
// conflicts Different swizzle patterns (32B, 64B, 128B) offer different
// trade-offs between access efficiency and memory usage
desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
if (!shared_layout.defined()) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else if (StructuralEqual()(shared_layout, linear_layout)) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else {
ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout.";
auto stride = as_const_int(shared_layout->InputShape()[0]);
auto continuous = as_const_int(shared_layout->InputShape()[1]);
ICHECK(stride != nullptr && continuous != nullptr);
if (StructuralEqual()(shared_layout, makeGemmABLayoutPadded(
// We also need to check if the shape satisfies the following doc:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
if (StructuralEqual()(shared_layout, makeQuarterBankSwizzleLayout(
*stride, *continuous,
shared_tensor->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else if (StructuralEqual()(
shared_layout,
makeQuarterBankSwizzleLayout(*stride, *continuous,
shared_tensor->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);
} else if (StructuralEqual()(
shared_layout,
......@@ -253,8 +900,19 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
makeFullBankSwizzleLayout(*stride, *continuous,
shared_tensor->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
} else if (StructuralEqual()(
shared_layout,
makeGemmABLayoutPadded(*stride, *continuous,
shared_tensor->dtype.bits()))) {
LOG(WARNING) << "Bulk copy cannot support a padded layout for src: "
<< src->name << ", dst: " << dst->name
<< ", fallback to normal copy";
return LowerNormalCopy(T, analyzer);
} else {
return Stmt();
LOG(WARNING) << "Came across unsupported swizzle layout for src: "
<< src->name << ", dst: " << dst->name
<< ", fallback to normal copy";
return LowerNormalCopy(T, analyzer);
}
}
......@@ -280,18 +938,24 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
int inner_box_dim_ = instruction_dim * shared_tensor->dtype.bytes();
if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE) &&
inner_box_dim_ % 256 != 0)
return Stmt();
#define CHECK_INNER_BOX_DIM(N) \
if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_##N##B) && \
inner_box_dim_ > N) \
return Stmt();
CHECK_INNER_BOX_DIM(32);
CHECK_INNER_BOX_DIM(64);
CHECK_INNER_BOX_DIM(128);
#undef CHECK_INNER_BOX_DIM
// Check inner_box_dim_ for each swizzle type in a cleaner way
struct SwizzleCheck {
int swizzle;
int max_dim;
};
static const SwizzleCheck swizzle_checks[] = {
{static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B), 32},
{static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B), 64},
{static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B), 128},
};
for (const auto &check : swizzle_checks) {
if (desc.swizzle == check.swizzle && inner_box_dim_ > check.max_dim) {
LOG(WARNING) << "TMA bulk copy cannot support a swizzled global layout "
"with inner_box_dim_ > "
<< check.max_dim << ", will be fallback to normal copy";
return LowerNormalCopy(T, analyzer);
}
}
Call create_descriptor =
Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs());
......@@ -336,6 +1000,13 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
return tma_copy;
}
/*!
* \brief Encode the TMA descriptor into an array of PrimExpr.
* This function serializes the TMA descriptor fields into a format suitable for
* passing to the create_tma_descriptor() builtin function. The encoding follows
* the expected argument order for the TMA descriptor creation.
* \return Array of PrimExpr representing the encoded TMA descriptor.
*/
Array<PrimExpr> TMADesc::EncodeCallArgs() const {
Array<PrimExpr> args;
args.reserve(rank * 4 + 7);
......@@ -359,8 +1030,14 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
return args;
}
DataType cuTensorMapType() { return DataType::UInt(8, 128); }
/*!
* \brief Constructor for Conv2DIm2ColOp.
* This operation performs im2col transformation for 2D convolution on GPU using
* TMA. It extracts patches from the input tensor and rearranges them for
* efficient matrix multiplication. \param args Array of PrimExpr representing
* the arguments of the Conv2DIm2ColOp. \param vmap BufferMap mapping original
* buffer names to new buffer names.
*/
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
src = vmap[GetVarFromAccessPtr(args[0])];
dst = vmap[GetVarFromAccessPtr(args[1])];
......@@ -373,6 +1050,16 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
eviction_policy = args[8].as<IntImm>().value()->value;
}
/*!
* \brief Lower the Conv2DIm2ColOp to PTX code.
* This function generates optimized im2col transformation using TMA
* instructions. It creates a TMA descriptor for the im2col operation, handling
* convolution parameters like kernel size, stride, padding, and dilation. The
* operation is optimized for Hopper architecture with support for different
* shared memory layouts. \param T LowerArgs containing target and layout map.
* \param analyzer Arithmetic analyzer for simplification.
* \return Stmt representing the PTX code for the Conv2DIm2ColOp.
*/
Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
ICHECK(TargetIsHopper(T.target));
......@@ -497,6 +1184,14 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
return tma_copy;
}
/*!
* \brief Encode the TMA im2col descriptor into an array of PrimExpr.
* This function serializes the TMA im2col descriptor fields for passing to the
* create_tma_im2col_descriptor() builtin function. It includes
* convolution-specific parameters like kernel size, stride, padding, and
* dilation in addition to standard tensor descriptor fields. \return Array of
* PrimExpr representing the encoded TMA im2col descriptor.
*/
Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
Array<PrimExpr> args;
args.reserve(rank * 5 + 5);
......@@ -524,10 +1219,24 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
return args;
}
// Register the Copy operation with TVM's TIR system
// This makes the copy operation available for use in TVM programs
// - Takes 4 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma
// - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Copy, copy)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// Register the Conv2DIm2Col operation with TVM's TIR system
// This operation performs im2col transformation for 2D convolutions using TMA
// - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride,
// dilation, padding, eviction_policy
// - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
.set_num_inputs(9)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
} // namespace tvm
\ No newline at end of file
/*!
 * \file tl/op/copy.h
* \brief Define element-wise and copy-related operators for TVM TensorIR
* Lowering.
*
* This header declares the Copy operator and related operator descriptors
* such as TMADesc and TMAIm2ColDesc, as well as a Conv2DIm2Col special
* operator.
*/
#ifndef TVM_TL_OP_COPY_H_
#define TVM_TL_OP_COPY_H_
#include "op.h"
#include "parallel.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
* \brief Descriptor for Tensor Memory Access (TMA) copy operations.
*
* Contains meta-information required to perform global-to-shared memory copy
* using Tensor Memory Accelerator (TMA) hardware instructions. It is mainly
* used to describe the shape, strides, and data layout for both source and
* shared memory buffers.
*/
struct TMADesc {
size_t rank; // Tensor rank (number of dimensions)
int data_type; // Data type identifier (numeric code)
Array<PrimExpr> global_shape; // Shape of the source tensor in global memory
Array<PrimExpr>
global_stride; // Strides of the source tensor in global memory
Array<PrimExpr> smem_box; // Block shape in shared memory
Array<PrimExpr> smem_stride; // Strides in shared memory layout
PrimExpr global_addr; // Base address in global memory
int swizzle; // Swizzle parameter for memory layout transform
int interleave; // Interleave parameter for optimization
int oob_fill; // Out-of-bound fill policy
int l2_promotion; // Whether to promote data to L2 cache
/*!
* \brief Encode descriptor fields into an argument array for runtime calls.
*/
Array<PrimExpr> EncodeCallArgs() const;
};
/*!
* \brief Descriptor for TMA-based im2col transformation used in Conv2D.
*
* This supports extracting patches from the input image (im2col)
* for convolution lowering, storing them in shared memory.
*/
struct TMAIm2ColDesc {
size_t rank; // Rank of the tensor
int data_type; // Data type identifier
Array<PrimExpr> global_shape; // Shape of input tensor in global memory
Array<PrimExpr> global_stride; // Stride in global memory
Array<PrimExpr> elem_stride; // Stride at element level (per axis)
Array<PrimExpr> lower_corner; // Lower bound offsets for the extraction window
// (rank - 2 dims)
Array<PrimExpr> upper_corner; // Upper bound offsets for the extraction window
// (rank - 2 dims)
PrimExpr global_addr; // Base address in global memory
int smem_box_pixel; // Pixel dimension of shared memory box
int smem_box_channel; // Channel dimension of shared memory box
int swizzle; // Memory swizzle setting
int interleave; // Memory interleaving setting
int oob_fill; // Out-of-bound fill policy
int l2_promotion; // Whether to enable L2 cache promotion
/*!
* \brief Encode descriptor fields into runtime arguments.
*/
Array<PrimExpr> EncodeCallArgs() const;
};
/*!
* \brief Copy operator for transferring data between buffers.
*
* This class implements a generic copy operator in TensorIR Lowering for
* block-wise or element-wise data transfer, possibly optimized with
* parallelization or TMA hardware acceleration.
*/
class Copy : public Operator {
public:
/*!
* \brief Constructor.
* \param args Expression arguments for the copy.
* \param vmap Buffer variable mapping.
*/
Copy(Array<PrimExpr> args, BufferMap vmap);
/*!
* \brief Lower the copy operator to a TIR statement.
* \param T Arguments for lowering.
* \param analyzer Analyzer for simplification and bounds checks.
*/
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
/*!
* \brief Infer buffer layouts after applying this operator.
* \param T Arguments for layout inference.
* \param level Level of inference (basic or detailed).
*/
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
/*!
* \brief Get the TVM Op handle corresponding to this Copy op.
*/
static const Op &Get();
/*!
* \brief Copy instruction type.
*/
enum class CopyInst {
kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy
kLDSM = 1, // ldmatrix memory copy
kSTSM = 2, // stmatrix memory copy
kBulkLoad = 3, // utilize tma load
kBulkStore = 4, // utilize tma store
};
/*!
* \brief Check if bulk copy is supported.
*/
bool CheckBulkLoad(Target target) const;
/*!
* \brief Check if bulk store is supported.
*/
bool CheckBulkStore(Target target) const;
/*!
   * \brief Check if LDSM (ldmatrix) copy is supported.
*/
bool CheckLDSMCopy(Target target) const;
/*!
   * \brief Check if STSM (stmatrix) copy is supported.
*/
bool CheckSTSMCopy(Target target) const;
/*!
* \brief Get the copy instruction type.
*/
CopyInst GetCopyInst(Target target, bool disable_tma_lower) const;
/*!
* \brief Copy constructor (deep clones ParallelOp if present).
*/
Copy(const Copy &other)
: args_(other.args_), src(other.src), dst(other.dst),
src_range(other.src_range), dst_range(other.dst_range),
coalesced_width(other.coalesced_width), disable_tma(other.disable_tma) {
// Deep copy ParallelOp if it exists
if (other.par_op_)
par_op_ = std::unique_ptr<ParallelOp>(
static_cast<ParallelOp *>(other.par_op_->Clone().release()));
}
/*!
* \brief Clone this copy operator.
*/
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Copy>(*this);
}
protected:
/*!
* \brief Generate lowering for bulk/global-to-shared copy.
*/
Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const;
/*!
   * \brief Generate lowering for LDSM/STSM matrix copy (between shared memory
   * and register fragments).
*/
Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const;
/*!
* \brief Generate lowering for normal copy.
*/
Stmt LowerNormalCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
/*!
* \brief Generate SIMT (thread-level) loop for copying.
*/
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
/*!
* \brief Compute linear layout for tma copy.
*/
Layout ComputeLinearLayout(const Buffer &shared_tensor) const;
/*!
* \brief Create iterator variables for multi-dimensional copy loops.
*/
Array<IterVar> MakeIterVars() const;
/*!
* \brief Calculate source or destination indices from iteration vars.
* \param ivs Iterator variables from MakeIterVars().
* \param src_dst 0 = make source indices, 1 = make destination indices.
*/
Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
/*!
* \brief Construct the boundary predicate for valid copy (to avoid OOB).
* \param analyzer Arithmetic analyser for simplification.
* \param ivs Iterator variables.
* \param extents Extent expressions for the relevant buffer.
* \param src_dst 0 = predicate for source, 1 = predicate for destination.
*/
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
Array<PrimExpr> args_; // Copy parameters (indices, sizes, etc.)
Buffer src, dst; // Source and destination buffers
Array<Range> src_range, dst_range; // Ranges for each dimension in src and dst
IntImm coalesced_width; // Width (in elements) for coalesced memory access
Bool disable_tma = Bool(false); // Whether to disable TMA acceleration
std::unique_ptr<ParallelOp>
par_op_; // Optional associated parallelization operator
enum class EvictionPolicy {
kEvictNormal = 0,
kEvictFirst = 1,
kEvictLast = 2,
};
int eviction_policy; // Policy for cache eviction
};
/*!
* \brief Special operator for Conv2D im2col transformation.
*
* This operator converts input image layout into columnar format suitable
* for matrix multiplication-based convolution lowering.
*/
class Conv2DIm2ColOp : public Operator {
public:
/*!
* \brief Constructor.
* \param args Op arguments (convolution parameters, shapes, etc.)
* \param vmap Variable buffer mapping.
*/
Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
/*!
* \brief Lower to TIR statement.
*/
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
/*!
* \brief Get TVM Op handle.
*/
static const Op &Get();
/*!
* \brief Clone this operator.
*/
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Conv2DIm2ColOp>(*this);
}
private:
Buffer src, dst; // Source (input feature map) and destination (im2col matrix)
int stride; // Stride for convolution
int padding; // Padding amount
int dilation; // Dilation factor
int kernel; // Kernel size
int eviction_policy; // Cache eviction policy
PrimExpr nhw_step; // Step size in NHW dimensions
PrimExpr c_step; // Step size in channel dimension
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_COPY_H_
\ No newline at end of file
......@@ -22,363 +22,6 @@ namespace tl {
using namespace tir;
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto expr = args[i];
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region.GetRanges();
bf[i] = region.GetBuffer();
}
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
auto coalesced_width = Downcast<IntImm>(args[2]);
if (coalesced_width->value > 0) {
this->coalesced_width = coalesced_width;
}
}
if (args.size() >= 4) {
auto disable_tma = Downcast<Bool>(args[3]);
this->disable_tma = disable_tma;
}
if (args.size() >= 5) {
this->eviction_policy = args[4].as<IntImmNode>()->value;
}
}
Array<IterVar> Copy::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
if (is_one(src_range[i]->extent))
continue;
Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
idx++;
loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
}
return loop_vars;
}
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent))
indices.push_back(ranges[i]->min);
else {
indices.push_back(ranges[i]->min + ivs[idx]->var);
idx++;
}
}
ICHECK(idx == ivs.size())
<< "idx = " << idx << ", ivs.size() = " << ivs.size()
<< "src name = " << src->name << ", dst name = " << dst->name;
return indices;
}
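As a worked example of the convention in `MakeIterVars`/`MakeIndices` (unit-extent dimensions keep their fixed offset, every other dimension consumes one fresh loop variable, in order), the following self-contained sketch mirrors the logic with plain structs and strings; the names and types are simplifications, not the TVM API.

```cpp
// Standalone illustration of the MakeIterVars/MakeIndices convention using
// plain structs and strings (assumed simplification, not the TVM API).
#include <cstdio>
#include <string>
#include <vector>

struct CopyRange { int min, extent; };

// Unit-extent dimensions keep their fixed `min`; every other dimension
// consumes one loop variable, in order, mirroring Copy::MakeIndices above.
std::vector<std::string> MakeIndices(const std::vector<CopyRange>& ranges,
                                     const std::vector<std::string>& ivs) {
  std::vector<std::string> indices;
  size_t idx = 0;
  for (const CopyRange& r : ranges) {
    if (r.extent == 1) {
      indices.push_back(std::to_string(r.min));
    } else {
      indices.push_back(std::to_string(r.min) + " + " + ivs[idx]);
      ++idx;
    }
  }
  return indices;
}

int main() {
  // Copy of A[2:6, 3]: {min=2, extent=4} covers rows 2..5, {min=3, extent=1}
  // fixes the column, so only one loop variable "i" is needed.
  std::vector<CopyRange> src_range = {{2, 4}, {3, 1}};
  for (const auto& s : MakeIndices(src_range, {"i"}))
    std::printf("%s\n", s.c_str());  // prints "2 + i" then "3"
}
```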
PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs, Array<PrimExpr> extents,
int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent))
continue;
PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
cond = ranges[i]->min + ivs[idx]->var >= 0;
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
idx++;
}
if (cond_list.empty())
return {};
else {
PrimExpr cond = cond_list[0];
for (size_t i = 1; i < cond_list.size(); i++)
cond = And(cond, cond_list[i]);
return cond;
}
}
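The boundary predicate built above keeps only the clauses the analyzer cannot discharge statically. A simplified, self-contained illustration follows; a plain interval check stands in for `analyzer->CanProve`, and the lower-bound clause is omitted for brevity.

```cpp
// Simplified, self-contained sketch of the boundary-predicate rule above.
// A plain interval check stands in for analyzer->CanProve, and the lower-bound
// (>= 0) clause is omitted for brevity.
#include <cstdio>
#include <string>
#include <vector>

struct Dim { int min, extent, buffer_extent; };

std::string MakePredicate(const std::vector<Dim>& dims) {
  std::vector<std::string> conds;
  char iv = 'i';
  for (const Dim& d : dims) {
    if (d.extent == 1) continue;             // unit dimensions carry no loop variable
    if (d.min + d.extent > d.buffer_extent)  // upper bound not provable
      conds.push_back(std::to_string(d.min) + " + " + std::string(1, iv) +
                      " < " + std::to_string(d.buffer_extent));
    ++iv;
  }
  std::string out;
  for (size_t i = 0; i < conds.size(); ++i) out += (i ? " && " : "") + conds[i];
  return out.empty() ? "true" : out;
}

int main() {
  // Rows [120, 136) of a 128-row buffer can run out of bounds, so a clause is
  // emitted; the fully covered 64-wide column dimension needs none.
  std::printf("%s\n", MakePredicate({{120, 16, 128}, {0, 64, 64}}).c_str());
  // prints: 120 + i < 128
}
```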
For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0;
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial,
BufferStore(dst, BufferLoad(src, {0}), {0}));
}
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
ICHECK(loop_vars.size() <= dst_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
PrimExpr value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype)
value = Cast(dst->dtype, value);
if (src_predicate.defined())
value = if_then_else(src_predicate, value, make_zero(dst->dtype));
Stmt body = BufferStore(dst, value, dst_indices);
if (dst_predicate.defined())
body = IfThenElse(dst_predicate, body);
for (int i = loop_vars.size() - 1; i >= 0; i--) {
Map<String, ObjectRef> annotations = {};
if (coalesced_width.defined()) {
annotations.Set("coalesced_width", coalesced_width);
}
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body, std::nullopt, annotations);
}
return Downcast<For>(body);
}
Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
if (ldsm_stmt.defined())
return ldsm_stmt;
if (!disable_tma) {
Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer);
if (bulk_copy_stmt.defined())
return bulk_copy_stmt;
}
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto transformed_loop =
Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
For vectorized_thread_loop;
auto par_op = std::make_unique<ParallelOp>(transformed_loop);
if (is_cpu_target) {
vectorized_thread_loop = VectorizeLoop(transformed_loop);
} else {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
}
auto loop_layout = par_op->GetLoopLayout();
auto thread_var = T.thread_var;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
vectorized_thread_loop = VectorizeLoop(thread_loop);
}
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
}
return vectorized_thread_loop;
}
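When neither ldmatrix/stmatrix nor TMA applies, the SIMT fallback above fuses the parallel loop, partitions it across the thread extent, and vectorizes the per-thread body. The following hand-written analogue shows the resulting access pattern under the assumption of an evenly divisible, contiguous copy; it is an illustration, not generated output.

```cpp
// Hand-written analogue of the SIMT fallback: fuse the copy into one flat
// loop, partition it across threads, then vectorize the per-thread body.
#include <cstdio>
#include <vector>

void SimtCopy(const float* src, float* dst, int total, int tid,
              int num_threads) {
  int per_thread = total / num_threads;  // PartitionLoop over thread_var
  int base = tid * per_thread;
  for (int v = 0; v < per_thread; ++v)   // VectorizeLoop widens this in TIR
    dst[base + v] = src[base + v];
}

int main() {
  std::vector<float> src(128 * 128, 1.0f), dst(128 * 128, 0.0f);
  for (int tid = 0; tid < 128; ++tid)  // emulate 128 threads sequentially
    SimtCopy(src.data(), dst.data(), 128 * 128, tid, 128);
  std::printf("%.1f\n", dst[12345]);   // 1.0 once every thread has run
}
```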
Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
// Check buffer scope
bool is_ldmatrix;
if (TargetHasLdmatrix(T.target) && src.scope() == "shared.dyn" &&
dst.scope() == "local.fragment") {
is_ldmatrix = true;
} else if (TargetHasStmatrix(T.target) && dst.scope() == "shared.dyn" &&
src.scope() == "local.fragment") {
is_ldmatrix = false;
} else {
return Stmt();
}
// Check no predicates
Array<IterVar> loop_vars = MakeIterVars();
if (loop_vars.size() < 2)
return Stmt();
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
if (src_predicate.defined() || dst_predicate.defined())
return Stmt();
Buffer shared_tensor = is_ldmatrix ? src : dst;
Buffer local_tensor = is_ldmatrix ? dst : src;
Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
Array<PrimExpr> local_indices_transformed =
local_layout->Forward(local_indices);
local_tensor = T.buffer_remap[local_tensor];
// currently only support 1-d case
if (local_layout->OutputDim() != 1)
return Stmt();
Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
Array<PrimExpr> shared_indices_transformed = shared_indices;
Layout shared_layout;
if (T.buffer_remap.count(shared_tensor)) {
shared_layout = T.layout_map[shared_tensor];
shared_tensor = T.buffer_remap[shared_tensor];
shared_indices_transformed = shared_layout->Forward(shared_indices);
}
// Check local_layout follows 8x8 layout
bool is_transposed;
IterVar col_var = loop_vars[loop_vars.size() - 1];
IterVar row_var = loop_vars[loop_vars.size() - 2];
PrimExpr local_layout_thread_map =
FloorMod(local_layout->ForwardThread(local_indices, std::nullopt), 32);
PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
PrimExpr matrix_8x8_thread_map_trans =
makeGemmFragment8x8Transposed()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
PrimExpr local_indices_flattened =
local_tensor.OffsetOf(local_indices_transformed).back();
if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, col_var->var,
col_var->dom->extent, 2, analyzer)) {
is_transposed = false;
} else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans,
local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, row_var->var,
row_var->dom->extent, 2, analyzer)) {
is_transposed = true;
} else {
return Stmt();
}
// Check shared_layout is 16 bytes continuous
if (shared_tensor->dtype.bytes() != 2)
return Stmt();
PrimExpr flattened_indice =
shared_tensor.OffsetOf(shared_indices_transformed).back();
if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
loop_vars.back()->dom->extent, 8, analyzer))
return Stmt();
// Can only support local_range to be a full range
for (size_t i = 0; i < dst_range.size(); i++) {
if (!is_zero(dst_range[i]->min) ||
!analyzer->CanProveEqual(dst_range[i]->extent, dst->shape[i]))
return Stmt();
}
// Do the lowering here, try vectorized ldmatrix/stmatrix by 4/2/1
PrimExpr extent = local_tensor->shape[0];
int num = 1;
if (analyzer->CanProveEqual(FloorMod(extent, 8), 0))
num = 4;
else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0))
num = 2;
Array<PrimExpr> args;
const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatrix();
args.push_back(static_cast<int>(is_transposed));
args.push_back(num);
// Create shared address with regard to local address
// if not transpose
// coords = Inverse(base + 2 * (thread / 8) % num, warp + (thread % 8) * 4))
// if transpose
// coords = Inverse(base + 2 * (thread / 8) % num + thread % 2, warp + thread
// % 8 / 2)
Var local_iter("i");
Layout inv = local_layout->Inverse();
Array<PrimExpr> shared_coords;
PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
if (!is_transposed)
shared_coords = inv->Forward(
{local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
warp + FloorMod(T.thread_var, 8) * 4});
else
shared_coords = inv->Forward(
{local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
FloorMod(T.thread_var, 2),
warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
shared_coords.pop_back(); // remove rep
if (shared_layout.defined())
shared_coords = shared_layout->Forward(shared_coords);
PrimExpr shared_addr = shared_tensor.access_ptr(
is_ldmatrix ? 1 : 2, DataType::Handle(), 1,
shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
args.push_back(shared_addr);
if (is_ldmatrix) {
// Can only support same dtype for ldmatrix
if (local_tensor->dtype != shared_tensor->dtype)
return Stmt();
PrimExpr local_addr = local_tensor.access_ptr(
2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num));
args.push_back(local_addr);
} else {
for (int i = 0; i < num; i++) {
PrimExpr value0 =
BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i});
PrimExpr value1 =
BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1});
if (local_tensor->dtype != shared_tensor->dtype) {
value0 = Cast(shared_tensor->dtype, value0);
value1 = Cast(shared_tensor->dtype, value1);
}
PrimExpr value_packed =
Call(DataType::Int(32), pack_b16(), {value0, value1});
args.push_back(value_packed);
}
}
auto body = Evaluate(Call(DataType::Handle(), op, args));
For for_node =
For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
for_node = LoopPragmaUnroll(for_node);
auto range = T.thread_bounds;
if (range.defined()) {
auto thread_var = T.thread_var;
auto thread_var_with_offset = thread_var - range->min;
for_node.CopyOnWrite()->body =
Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
}
return for_node;
}
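The `num` selection above (4/2/1) decides how many 8x8 tiles one `ptx_ldmatrix`/`ptx_stmatrix` call moves, and the surrounding loop then issues `extent / (2 * num)` calls per thread. A small standalone sketch of that arithmetic, with plain ints standing in for `PrimExpr`s and `arith::Analyzer`; the extent value in `main` is only an example.

```cpp
// Minimal sketch of the ldmatrix/stmatrix width selection made above.
#include <cstdio>

// Pick how many 8x8 tiles one ptx_ldmatrix/ptx_stmatrix call moves.
int ChooseMatrixNum(int fragment_extent) {
  if (fragment_extent % 8 == 0) return 4;  // .x4
  if (fragment_extent % 4 == 0) return 2;  // .x2
  return 1;                                // .x1
}

int main() {
  int extent = 16;                      // elements in one thread's fragment
  int num = ChooseMatrixNum(extent);    // -> 4
  int iterations = extent / (2 * num);  // matches For(local_iter, 0, extent / (2*num))
  std::printf("num=%d iterations=%d\n", num, iterations);  // num=4 iterations=2
}
```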
LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
// Use parallel op to infer the layout
if (par_op_ == nullptr) {
arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
}
return par_op_->InferLayout(T, level);
}
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
if (args[0]->IsInstance<BufferLoadNode>()) {
......@@ -479,11 +122,6 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
}
TIR_REGISTER_TL_OP(Copy, copy)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_REGISTER_TL_OP(Fill, fill)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
......
......@@ -15,53 +15,6 @@ namespace tl {
using namespace tir;
class Copy : public Operator {
public:
Copy(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
Copy(const Copy &other)
: args_(other.args_), src(other.src), dst(other.dst),
src_range(other.src_range), dst_range(other.dst_range),
coalesced_width(other.coalesced_width), disable_tma(other.disable_tma) {
// No clone nullptr
if (other.par_op_)
par_op_ = std::unique_ptr<ParallelOp>(
static_cast<ParallelOp *>(other.par_op_->Clone().release()));
}
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Copy>(*this);
}
protected:
Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
Array<IterVar> MakeIterVars() const;
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
Array<PrimExpr> args_;
Buffer src, dst;
Array<Range> src_range, dst_range;
IntImm coalesced_width;
Bool disable_tma = Bool(false);
std::unique_ptr<ParallelOp> par_op_;
int eviction_policy;
};
class Fill : public Operator {
public:
Fill(Array<PrimExpr> args, BufferMap vmap);
......
......@@ -49,7 +49,6 @@ struct LowerArgs {
AddWorkspaceCallback AddWorkspace;
LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap;
bool disable_tma_lower;
};
struct LayoutInferArgs {
......
......@@ -14,7 +14,6 @@
#include <vector>
#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "arith/pattern_match.h"
#include "target/source/ptx.h"
......@@ -1100,7 +1099,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
ss << "tl::tma_store";
}
print_extern_call_stmt(ss.str(), 0, 1);
} else if (op->op.same_as(tl::ptx_ldmatirx())) {
} else if (op->op.same_as(tl::ptx_ldmatrix())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
......
......@@ -14,7 +14,6 @@
#include <vector>
#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "target/source/ptx.h"
namespace tvm {
......