Unverified Commit 5c11d245 authored by Lei Wang, committed by GitHub

[Refactor] Merge bulk copy into copy and improve layout inference for bulk copy (#746)

* [Refactor] Merge bulk copy into copy and refactor layout inference for bulk copy

* Deleted the `bulk_copy` operator implementation and its header file, as they are no longer needed.
* Introduced a new function `cuTensorMapType()` to return the data type for CUDA tensor mapping.
* Updated related files to reflect these changes, ensuring that the codebase remains clean and maintainable.

* lint fix

* Fix typos in intrinsic names and remove an unused print statement in block_sparse_attn_tilelang.py. Update references from `ptx_ldmatirx` to `ptx_ldmatrix` across multiple files for consistency.

* remove bulk copy

* Refactor copy and atomic add operations to support TMA lower configuration

- Updated `GetCopyInst` to accept a `disable_tma_lower` parameter, allowing for conditional usage of TMA in bulk load/store operations.
- Modified `Lower` method in `Copy` to incorporate the new TMA configuration.
- Refactored `AtomicAdd::Lower` to streamline layout inference and vectorization logic.
- Removed unused `disable_tma_lower` field from `LowerArgs` structure for clarity.
- Enhanced atomic add vectorization by replacing the buggy implementation with a more robust loop vectorization approach.
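The atomic-add path is unchanged at the language level; only its lowering now goes through the generic `VectorizeLoop`. A minimal sketch of a kernel that exercises this path, assuming the usual `T.atomic_add` / `T.Parallel` intrinsics (the kernel itself is hypothetical and not part of this PR):

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def rowsum_atomic(M=1024, N=1024, blk=128, dtype="float32"):
    # Hypothetical example: each block accumulates its tile into a global row
    # sum via T.atomic_add. This PR only changes how the generated SIMT loop
    # is vectorized (generic VectorizeLoop instead of VectorizeAtomicAdd).
    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), Out: T.Tensor((M,), dtype)):
        with T.Kernel(T.ceildiv(M, blk), T.ceildiv(N, blk), threads=128) as (bx, by):
            A_shared = T.alloc_shared((blk, blk), dtype)
            T.copy(A[bx * blk, by * blk], A_shared)
            for i, j in T.Parallel(blk, blk):
                T.atomic_add(Out[bx * blk + i], A_shared[i, j])

    return main
```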

* Enhance TMA bulk copy logic in `LowerBulkCopy` method

- Added a condition to set `desc.swizzle` to `CU_TENSOR_MAP_SWIZZLE_NONE` when `shared_layout` matches `linear_layout`, improving clarity in layout handling.
- Updated warning log to provide more detailed information about fallback scenarios, including source and destination buffer names and shapes, enhancing debugging capabilities.

* lint fix

* Remove fallback logging for non-swizzled global layout in `LowerBulkCopy` method to streamline the bulk copy logic. This change enhances code clarity by eliminating unnecessary warning messages related to inner box dimensions.

* Enhance reshape kernel compilation in `run_reshape` and `run_reshape_smem_1d_2_2d` functions

- Updated the `tl.compile` method to include `pass_configs` that disable TMA lower and warp specialization, addressing shared memory layout transformation limitations (see the sketch below).
- Added TODO comments to indicate the need for further improvements in shared memory handling.
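For reference, a rough sketch of what the updated compilation call looks like; the toy copy kernel below is invented for illustration (the real change touches `run_reshape` and `run_reshape_smem_1d_2_2d`), while the `pass_configs` keys are the ones introduced in this commit:

```python
import tilelang
import tilelang.language as T


def make_copy_kernel(M=1024, N=1024, blk=128, dtype="float16"):
    # Hypothetical stand-in for the reshape kernels touched by this commit.
    @T.prim_func
    def copy_kernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, blk), T.ceildiv(N, blk), threads=128) as (bx, by):
            A_shared = T.alloc_shared((blk, blk), dtype)
            T.copy(A[bx * blk, by * blk], A_shared)
            T.copy(A_shared, B[bx * blk, by * blk])

    return copy_kernel


# Disable TMA lowering and warp specialization for kernels whose shared-memory
# layout transformation is not yet supported, mirroring the run_reshape change.
kernel = tilelang.compile(
    make_copy_kernel(),
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
```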

* Update `native_sparse_attention` function to include TMA configuration options

- Added `pass_configs` to the JIT decorator to disable TMA lower and warp specialization, addressing potential issues with shared memory layout transformations.
- Updated comments to clarify modifications in tensor shapes for inference, specifically setting `q` sequence length to 1.

* Refactor JIT decorator formatting in `native_sparse_attention` function

- Improved readability by reformatting the JIT decorator parameters for `native_sparse_attention`, ensuring consistent style across the codebase.
- No functional changes were made; this update focuses on code clarity and maintainability.

* Enhance thread management and logging in TileLang compilation

- Added a method to check if printing is enabled during compilation, improving control over logging behavior.
- Updated the JIT kernel class to utilize the new method for logging compilation status, ensuring consistent and clear output.
- Added comments to clarify the purpose of changes and improve code readability.

* Add warp specialization scope and refactor register management in TileLang

- Introduced a new constant `kWarpSpecializationScope` in `builtin.h` for better attribute management.
- Removed the `SetMaxNRegCollector` class and its related logic from `warp_specialized_rewriter.cc`, streamlining the warp specialization process.
- Added functions `annotate_producer_reg_dealloc` and `annotate_consumer_reg_alloc` in `builtin.py` to facilitate register management.
- Implemented `AnnotateWarpGroupRegAlloc` in `__init__.py` to inject register allocation calls into warp-specialized functions, enhancing the overall register handling in the compilation process.
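At the kernel level, the visible effect is a renamed opt-out intrinsic: the examples in this diff switch from `T.no_set_max_nreg()` to `T.disable_warp_group_reg_alloc()`, while the new `annotate_producer_reg_dealloc` / `annotate_consumer_reg_alloc` calls are injected automatically by `AnnotateWarpGroupRegAlloc` rather than written by hand. A hedged sketch (the kernel is illustrative only):

```python
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def staged_copy(M=1024, N=1024, blk=128, dtype="float16"):
    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, blk), T.ceildiv(N, blk), threads=128) as (bx, by):
            A_shared = T.alloc_shared((blk, blk), dtype)
            # Opt out of the producer/consumer register reallocation that the
            # AnnotateWarpGroupRegAlloc pass would otherwise insert
            # (previously spelled T.no_set_max_nreg()).
            T.disable_warp_group_reg_alloc()
            T.copy(A[bx * blk, by * blk], A_shared)
            T.copy(A_shared, B[bx * blk, by * blk])

    return main
```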

* Refactor test for InjectSetMaxNReg pass in TileLang

- Improved readability by restructuring conditional checks and assertions in the test cases.
- Enhanced clarity in the collection of `set_max_nreg` calls by simplifying the logic.
- Ensured consistent formatting and spacing throughout the test functions for better maintainability.

* Enhance bulk copy and store checks in `Copy` class

- Updated scope validation for source and destination tensors in `CheckBulkLoad` and `CheckBulkStore` methods to include both `shared.dyn` and `shared` as valid options (see the sketch after this list).
- Modified `CheckLDSMCopy` and `CheckSTSMCopy` methods to accommodate the new scope validation, ensuring compatibility with shared memory configurations.
- Improved logging in `LowerBulkCopy` to provide clearer warnings regarding unsupported swizzle layouts, including source and destination names for better debugging.
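In user code this means a statically allocated shared buffer can now take the bulk-copy (TMA) path, not just `shared.dyn`. A sketch, assuming `T.alloc_shared` accepts an explicit `scope` argument (the kernel is illustrative, not from this PR):

```python
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def tile_copy_static_smem(M=1024, N=1024, blk=128, dtype="float16"):
    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, blk), T.ceildiv(N, blk), threads=128) as (bx, by):
            # "shared" (static shared memory) is now accepted by CheckBulkLoad /
            # CheckBulkStore alongside the default "shared.dyn".
            A_shared = T.alloc_shared((blk, blk), dtype, scope="shared")
            T.copy(A[bx * blk, by * blk], A_shared)   # global -> shared: bulk load candidate
            T.copy(A_shared, B[bx * blk, by * blk])   # shared -> global: bulk store candidate

    return main
```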

* lint fix
parent cb37bfef
@@ -192,7 +192,7 @@ def matmul_sp(M, N, K):
    # Clear out the accumulation buffer
    T.clear(C_local)
-   T.no_set_max_nreg()
+   T.disable_warp_group_reg_alloc()
    T.use_swizzle(panel_size=10, enable=enable_rasterization)
    T.annotate_layout({
...
@@ -52,7 +52,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
    T.fill(acc_o, 0)
    T.fill(logsum, 0)
    T.fill(scores_max, -T.infinity(accum_dtype))
-   T.no_set_max_nreg()
+   T.disable_warp_group_reg_alloc()
    loop_range = T.ceildiv(seqlen_kv, block_N)
    for k in T.Pipelined(loop_range, num_stages=2):
        T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], qKV_shared)
...
@@ -8,7 +8,14 @@ import tilelang.testing
tilelang.testing.set_random_seed(42)

-@tilelang.jit(out_idx=[-1])
+# TODO(lei): workaround, as threads is not divisible by warp group size,
+# auto warp specialization may have some bugs.
+@tilelang.jit(
+    out_idx=[-1],
+    pass_configs={
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    })
def native_sparse_attention(
    batch,
    heads,
@@ -22,7 +29,7 @@ def native_sparse_attention(
    if scale is None:
        scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
-   # Modified shapes for inference (q has seq_len=1)
+   # Modified shapes for inference (q has seq_len=1)a
    q_shape = [batch, 1, heads, dim]  # Changed seq_len to 1
    kv_shape = [batch, seq_len, head_kv, dim]
    block_indices_shape = [batch, 1, head_kv, selected_blocks]  # Changed seq_len to 1
@@ -167,8 +174,6 @@ def main():
        block_counts=block_counts,
        block_size=block_size,
    )
-   print("out", out)
-   print("ref", ref)
    torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
...
@@ -338,7 +338,7 @@ def matmul(M,
        C_shared: tilelang.layout.make_swizzled_layout(C_shared),
    })
    if threads == 512:
-       T.no_set_max_nreg()
+       T.disable_warp_group_reg_alloc()
    T.clear(C_local)
    for k in T.Pipelined(K // block_K, num_stages=num_stages):
...
import torch
import torch.nn.functional as F
import tilelang
-from tilelang.autotuner import *
import tilelang.language as T
import argparse
@@ -340,11 +339,10 @@ def main(BATCH: int = 1,
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

-   assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
+   torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
-   assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
-   assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
-   assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
+   torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
+   torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)

    def run():
        O_ref.backward(dO, retain_graph=True)
...
@@ -122,7 +122,7 @@ def tilelang_chunk_fwd_o(
    T.clear(A_fragment)
    T.clear(O_fragment)
-   T.no_set_max_nreg()
+   T.disable_warp_group_reg_alloc()
    for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
        T.copy(
            Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
...
@@ -101,7 +101,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
    })
    T.fill(A_fragment, 0)
-   T.no_set_max_nreg()
+   T.disable_warp_group_reg_alloc()
    for i_s in T.Parallel(block_S):
        Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
...
@@ -107,7 +107,7 @@ def tilelang_recompute_w_u_fwd(
        U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared),
    })
-   T.no_set_max_nreg()
+   T.disable_warp_group_reg_alloc()
    for i_s in T.Parallel(block_S):
        Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
        G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
...
@@ -178,7 +178,6 @@ def test_topk_sparse_attention():
    # Run tilelang kernel
    kernel = blocksparse_flashattn(
        BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
-   print(kernel.get_kernel_source())
    tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))

    # Compute reference
...
@@ -182,27 +182,25 @@ For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  Target target = T.target;
- bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
  auto simt_loop = MakeSIMTLoop(analyzer);
  auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
- For vectorized_thread_loop;
  auto par_op = std::make_unique<ParallelOp>(fused_loop);
- if (!is_cpu_target) {
-   std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
-                                     InferLevel::kFree};
-   for (auto level : levels) {
-     par_op->InferLayout(
-         {T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
-   }
-   auto loop_layout = par_op->GetLoopLayout();
-   Var thread_var = T.thread_var;
-   Range thread_bounds = T.thread_bounds;
-   auto thread_loop =
-       PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
-   vectorized_thread_loop = VectorizeAtomicAdd(
-       thread_loop, thread_var, thread_bounds, GetArchInt(target));
- }
+ std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
+                                   InferLevel::kFree};
+ for (auto level : levels) {
+   par_op->InferLayout(
+       {T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
+ }
+ auto loop_layout = par_op->GetLoopLayout();
+ Var thread_var = T.thread_var;
+ Range thread_bounds = T.thread_bounds;
+ auto thread_loop =
+     PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
+ // TODO(@dyq): buggy implementation, need to fix
+ // vectorized_thread_loop = VectorizeAtomicAdd(
+ //     thread_loop, thread_var, thread_bounds, GetArchInt(target));
+ auto vectorized_thread_loop = VectorizeLoop(thread_loop);
  if (par_op->GetPredicate(T.thread_var).defined()) {
    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
...
@@ -29,6 +29,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);

+DataType cuTensorMapType() { return DataType::UInt(8, 128); }
+
#define TIR_DEFINE_TL_BUILTIN(OpName)                                          \
  const Op &OpName() {                                                         \
    static const Op &op = Op::Get("tl." #OpName);                              \
@@ -78,7 +80,7 @@ TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TIR_DEFINE_TL_BUILTIN(ptx_ldmatirx)
+TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
...
@@ -15,6 +15,8 @@ namespace tl {
namespace attr {

static constexpr const char *kPaddingMap = "padding_map";
+static constexpr const char *kWarpSpecializationScope =
+    "kWarpSpecializationScope";

} // namespace attr

static constexpr const char *kDebugMergeSharedMemoryAllocations =
@@ -54,6 +56,14 @@ static constexpr const char *kDisableDynamicTailSplit =
 */
static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";

+/*!
+ * \brief Get the type of the CUDA tensor map
+ *
+ * DataType cuTensorMapType()
+ *
+ */
+DataType cuTensorMapType();
+
/*!
 * \brief tvm intrinsics for TMADescriptor creation for tiled load
 *
@@ -138,15 +148,15 @@ TVM_DLL const Op &mbarrier_expect_tx();
/*!
 * \brief tvm intrinsics for ldmatrix
 *
- * ptx_ldmatirx(transposed, num, shared_addr, local_addr)
+ * ptx_ldmatrix(transposed, num, shared_addr, local_addr)
 *
 */
-TVM_DLL const Op &ptx_ldmatirx();
+TVM_DLL const Op &ptx_ldmatrix();

/*!
 * \brief tvm intrinsics for stmatrix
 *
- * ptx_ldmatirx(transposed, num, shared_addr, int32_values...)
+ * ptx_ldmatrix(transposed, num, shared_addr, int32_values...)
 *
 */
TVM_DLL const Op &ptx_stmatrix();
...
/*!
* \file tl/op/bulk_copy.h
* \brief Bulk copy operator.
*
*/
#ifndef TVM_TL_OP_BULK_COPY_H_
#define TVM_TL_OP_BULK_COPY_H_
#include "elem.h"
namespace tvm {
namespace tl {
using namespace tir;
struct TMADesc {
size_t rank;
int data_type;
Array<PrimExpr> global_shape, global_stride;
Array<PrimExpr> smem_box, smem_stride;
PrimExpr global_addr;
int swizzle;
int interleave;
int oob_fill;
int l2_promotion;
Array<PrimExpr> EncodeCallArgs() const;
};
DataType cuTensorMapType();
struct TMAIm2ColDesc {
size_t rank;
int data_type;
Array<PrimExpr> global_shape, global_stride, elem_stride; // rank
Array<PrimExpr> lower_corner, upper_corner; // rank - 2
PrimExpr global_addr;
int smem_box_pixel, smem_box_channel;
int swizzle;
int interleave;
int oob_fill;
int l2_promotion;
Array<PrimExpr> EncodeCallArgs() const;
};
class Conv2DIm2ColOp : public Operator {
public:
Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Conv2DIm2ColOp>(*this);
}
private:
Buffer src, dst;
int stride, padding, dilation, kernel, eviction_policy;
PrimExpr nhw_step, c_step;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_BULK_COPY_H_
\ No newline at end of file
This diff is collapsed.
/*!
* \file tl/op/elem.h
* \brief Define element-wise and copy-related operators for TVM TensorIR
* Lowering.
*
* This header declares the Copy operator and related operator descriptors
* such as TMADesc and TMAIm2ColDesc, as well as a Conv2DIm2Col special
* operator.
*/
#ifndef TVM_TL_OP_COPY_H_
#define TVM_TL_OP_COPY_H_
#include "op.h"
#include "parallel.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
* \brief Descriptor for Tensor Memory Access (TMA) copy operations.
*
* Contains meta-information required to perform global-to-shared memory copy
* using Tensor Memory Accelerator (TMA) hardware instructions. It is mainly
* used to describe the shape, strides, and data layout for both source and
* shared memory buffers.
*/
struct TMADesc {
size_t rank; // Tensor rank (number of dimensions)
int data_type; // Data type identifier (numeric code)
Array<PrimExpr> global_shape; // Shape of the source tensor in global memory
Array<PrimExpr>
global_stride; // Strides of the source tensor in global memory
Array<PrimExpr> smem_box; // Block shape in shared memory
Array<PrimExpr> smem_stride; // Strides in shared memory layout
PrimExpr global_addr; // Base address in global memory
int swizzle; // Swizzle parameter for memory layout transform
int interleave; // Interleave parameter for optimization
int oob_fill; // Out-of-bound fill policy
int l2_promotion; // Whether to promote data to L2 cache
/*!
* \brief Encode descriptor fields into an argument array for runtime calls.
*/
Array<PrimExpr> EncodeCallArgs() const;
};
/*!
* \brief Descriptor for TMA-based im2col transformation used in Conv2D.
*
* This supports extracting patches from the input image (im2col)
* for convolution lowering, storing them in shared memory.
*/
struct TMAIm2ColDesc {
size_t rank; // Rank of the tensor
int data_type; // Data type identifier
Array<PrimExpr> global_shape; // Shape of input tensor in global memory
Array<PrimExpr> global_stride; // Stride in global memory
Array<PrimExpr> elem_stride; // Stride at element level (per axis)
Array<PrimExpr> lower_corner; // Lower bound offsets for the extraction window
// (rank - 2 dims)
Array<PrimExpr> upper_corner; // Upper bound offsets for the extraction window
// (rank - 2 dims)
PrimExpr global_addr; // Base address in global memory
int smem_box_pixel; // Pixel dimension of shared memory box
int smem_box_channel; // Channel dimension of shared memory box
int swizzle; // Memory swizzle setting
int interleave; // Memory interleaving setting
int oob_fill; // Out-of-bound fill policy
int l2_promotion; // Whether to enable L2 cache promotion
/*!
* \brief Encode descriptor fields into runtime arguments.
*/
Array<PrimExpr> EncodeCallArgs() const;
};
/*!
* \brief Copy operator for transferring data between buffers.
*
* This class implements a generic copy operator in TensorIR Lowering for
* block-wise or element-wise data transfer, possibly optimized with
* parallelization or TMA hardware acceleration.
*/
class Copy : public Operator {
public:
/*!
* \brief Constructor.
* \param args Expression arguments for the copy.
* \param vmap Buffer variable mapping.
*/
Copy(Array<PrimExpr> args, BufferMap vmap);
/*!
* \brief Lower the copy operator to a TIR statement.
* \param T Arguments for lowering.
* \param analyzer Analyzer for simplification and bounds checks.
*/
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
/*!
* \brief Infer buffer layouts after applying this operator.
* \param T Arguments for layout inference.
* \param level Level of inference (basic or detailed).
*/
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
/*!
* \brief Get the TVM Op handle corresponding to this Copy op.
*/
static const Op &Get();
/*!
* \brief Copy instruction type.
*/
enum class CopyInst {
kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy
kLDSM = 1, // ldmatrix memory copy
kSTSM = 2, // stmatrix memory copy
kBulkLoad = 3, // utilize tma load
kBulkStore = 4, // utilize tma store
};
/*!
* \brief Check if bulk copy is supported.
*/
bool CheckBulkLoad(Target target) const;
/*!
* \brief Check if bulk store is supported.
*/
bool CheckBulkStore(Target target) const;
/*!
* \brief Check if lds memory copy is supported.
*/
bool CheckLDSMCopy(Target target) const;
/*!
* \brief Check if stsm memory copy is supported.
*/
bool CheckSTSMCopy(Target target) const;
/*!
* \brief Get the copy instruction type.
*/
CopyInst GetCopyInst(Target target, bool disable_tma_lower) const;
/*!
* \brief Copy constructor (deep clones ParallelOp if present).
*/
Copy(const Copy &other)
: args_(other.args_), src(other.src), dst(other.dst),
src_range(other.src_range), dst_range(other.dst_range),
coalesced_width(other.coalesced_width), disable_tma(other.disable_tma) {
// Deep copy ParallelOp if it exists
if (other.par_op_)
par_op_ = std::unique_ptr<ParallelOp>(
static_cast<ParallelOp *>(other.par_op_->Clone().release()));
}
/*!
* \brief Clone this copy operator.
*/
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Copy>(*this);
}
protected:
/*!
* \brief Generate lowering for bulk/global-to-shared copy.
*/
Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const;
/*!
* \brief Generate lowering for LDS Memory Copy (shared memory to shared
* memory or smem usage).
*/
Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const;
/*!
* \brief Generate lowering for normal copy.
*/
Stmt LowerNormalCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
/*!
* \brief Generate SIMT (thread-level) loop for copying.
*/
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
/*!
* \brief Compute linear layout for tma copy.
*/
Layout ComputeLinearLayout(const Buffer &shared_tensor) const;
/*!
* \brief Create iterator variables for multi-dimensional copy loops.
*/
Array<IterVar> MakeIterVars() const;
/*!
* \brief Calculate source or destination indices from iteration vars.
* \param ivs Iterator variables from MakeIterVars().
* \param src_dst 0 = make source indices, 1 = make destination indices.
*/
Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
/*!
* \brief Construct the boundary predicate for valid copy (to avoid OOB).
* \param analyzer Arithmetic analyser for simplification.
* \param ivs Iterator variables.
* \param extents Extent expressions for the relevant buffer.
* \param src_dst 0 = predicate for source, 1 = predicate for destination.
*/
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
Array<PrimExpr> args_; // Copy parameters (indices, sizes, etc.)
Buffer src, dst; // Source and destination buffers
Array<Range> src_range, dst_range; // Ranges for each dimension in src and dst
IntImm coalesced_width; // Width (in elements) for coalesced memory access
Bool disable_tma = Bool(false); // Whether to disable TMA acceleration
std::unique_ptr<ParallelOp>
par_op_; // Optional associated parallelization operator
enum class EvictionPolicy {
kEvictNormal = 0,
kEvictFirst = 1,
kEvictLast = 2,
};
int eviction_policy; // Policy for cache eviction
};
/*!
* \brief Special operator for Conv2D im2col transformation.
*
* This operator converts input image layout into columnar format suitable
* for matrix multiplication-based convolution lowering.
*/
class Conv2DIm2ColOp : public Operator {
public:
/*!
* \brief Constructor.
* \param args Op arguments (convolution parameters, shapes, etc.)
* \param vmap Variable buffer mapping.
*/
Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
/*!
* \brief Lower to TIR statement.
*/
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
/*!
* \brief Get TVM Op handle.
*/
static const Op &Get();
/*!
* \brief Clone this operator.
*/
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Conv2DIm2ColOp>(*this);
}
private:
Buffer src, dst; // Source (input feature map) and destination (im2col matrix)
int stride; // Stride for convolution
int padding; // Padding amount
int dilation; // Dilation factor
int kernel; // Kernel size
int eviction_policy; // Cache eviction policy
PrimExpr nhw_step; // Step size in NHW dimensions
PrimExpr c_step; // Step size in channel dimension
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_COPY_H_
\ No newline at end of file
@@ -22,363 +22,6 @@ namespace tl {
using namespace tir;
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto expr = args[i];
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region.GetRanges();
bf[i] = region.GetBuffer();
}
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
auto coalesced_width = Downcast<IntImm>(args[2]);
if (coalesced_width->value > 0) {
this->coalesced_width = coalesced_width;
}
}
if (args.size() >= 4) {
auto disable_tma = Downcast<Bool>(args[3]);
this->disable_tma = disable_tma;
}
if (args.size() >= 5) {
this->eviction_policy = args[4].as<IntImmNode>()->value;
}
}
Array<IterVar> Copy::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
if (is_one(src_range[i]->extent))
continue;
Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
idx++;
loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
}
return loop_vars;
}
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent))
indices.push_back(ranges[i]->min);
else {
indices.push_back(ranges[i]->min + ivs[idx]->var);
idx++;
}
}
ICHECK(idx == ivs.size())
<< "idx = " << idx << ", ivs.size() = " << ivs.size()
<< "src name = " << src->name << ", dst name = " << dst->name;
return indices;
}
PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs, Array<PrimExpr> extents,
int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent))
continue;
PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
cond = ranges[i]->min + ivs[idx]->var >= 0;
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
idx++;
}
if (cond_list.empty())
return {};
else {
PrimExpr cond = cond_list[0];
for (size_t i = 1; i < cond_list.size(); i++)
cond = And(cond, cond_list[i]);
return cond;
}
}
For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0;
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial,
BufferStore(dst, BufferLoad(src, {0}), {0}));
}
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
ICHECK(loop_vars.size() <= dst_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
PrimExpr value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype)
value = Cast(dst->dtype, value);
if (src_predicate.defined())
value = if_then_else(src_predicate, value, make_zero(dst->dtype));
Stmt body = BufferStore(dst, value, dst_indices);
if (dst_predicate.defined())
body = IfThenElse(dst_predicate, body);
for (int i = loop_vars.size() - 1; i >= 0; i--) {
Map<String, ObjectRef> annotations = {};
if (coalesced_width.defined()) {
annotations.Set("coalesced_width", coalesced_width);
}
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body, std::nullopt, annotations);
}
return Downcast<For>(body);
}
Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
if (ldsm_stmt.defined())
return ldsm_stmt;
if (!disable_tma) {
Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer);
if (bulk_copy_stmt.defined())
return bulk_copy_stmt;
}
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto transformed_loop =
Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
For vectorized_thread_loop;
auto par_op = std::make_unique<ParallelOp>(transformed_loop);
if (is_cpu_target) {
vectorized_thread_loop = VectorizeLoop(transformed_loop);
} else {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
}
auto loop_layout = par_op->GetLoopLayout();
auto thread_var = T.thread_var;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
vectorized_thread_loop = VectorizeLoop(thread_loop);
}
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
}
return vectorized_thread_loop;
}
Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
// Check buffer scope
bool is_ldmatrix;
if (TargetHasLdmatrix(T.target) && src.scope() == "shared.dyn" &&
dst.scope() == "local.fragment") {
is_ldmatrix = true;
} else if (TargetHasStmatrix(T.target) && dst.scope() == "shared.dyn" &&
src.scope() == "local.fragment") {
is_ldmatrix = false;
} else {
return Stmt();
}
// Check no predicates
Array<IterVar> loop_vars = MakeIterVars();
if (loop_vars.size() < 2)
return Stmt();
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
if (src_predicate.defined() || dst_predicate.defined())
return Stmt();
Buffer shared_tensor = is_ldmatrix ? src : dst;
Buffer local_tensor = is_ldmatrix ? dst : src;
Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
Array<PrimExpr> local_indices_transformed =
local_layout->Forward(local_indices);
local_tensor = T.buffer_remap[local_tensor];
// currently only support 1-d case
if (local_layout->OutputDim() != 1)
return Stmt();
Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
Array<PrimExpr> shared_indices_transformed = shared_indices;
Layout shared_layout;
if (T.buffer_remap.count(shared_tensor)) {
shared_layout = T.layout_map[shared_tensor];
shared_tensor = T.buffer_remap[shared_tensor];
shared_indices_transformed = shared_layout->Forward(shared_indices);
}
// Check local_layout follows 8x8 layout
bool is_transposed;
IterVar col_var = loop_vars[loop_vars.size() - 1];
IterVar row_var = loop_vars[loop_vars.size() - 2];
PrimExpr local_layout_thread_map =
FloorMod(local_layout->ForwardThread(local_indices, std::nullopt), 32);
PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
PrimExpr matrix_8x8_thread_map_trans =
makeGemmFragment8x8Transposed()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
PrimExpr local_indices_flattened =
local_tensor.OffsetOf(local_indices_transformed).back();
if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, col_var->var,
col_var->dom->extent, 2, analyzer)) {
is_transposed = false;
} else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans,
local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, row_var->var,
row_var->dom->extent, 2, analyzer)) {
is_transposed = true;
} else {
return Stmt();
}
// Check shared_layout is 16 bytes continuous
if (shared_tensor->dtype.bytes() != 2)
return Stmt();
PrimExpr flattened_indice =
shared_tensor.OffsetOf(shared_indices_transformed).back();
if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
loop_vars.back()->dom->extent, 8, analyzer))
return Stmt();
// Can only support local_range to be a full range
for (size_t i = 0; i < dst_range.size(); i++) {
if (!is_zero(dst_range[i]->min) ||
!analyzer->CanProveEqual(dst_range[i]->extent, dst->shape[i]))
return Stmt();
}
// Do the lowering here, try vectorized ldmatrix/stmatrix by 4/2/1
PrimExpr extent = local_tensor->shape[0];
int num = 1;
if (analyzer->CanProveEqual(FloorMod(extent, 8), 0))
num = 4;
else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0))
num = 2;
Array<PrimExpr> args;
const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatrix();
args.push_back(static_cast<int>(is_transposed));
args.push_back(num);
// Create shared address with regard to local address
// if not transpose
// coords = Inverse(base + 2 * (thread / 8) % num, warp + (thread % 8) * 4))
// if transpose
// coords = Inverse(base + 2 * (thread / 8) % num + thread % 2, warp + thread
// % 8 / 2)
Var local_iter("i");
Layout inv = local_layout->Inverse();
Array<PrimExpr> shared_coords;
PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
if (!is_transposed)
shared_coords = inv->Forward(
{local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
warp + FloorMod(T.thread_var, 8) * 4});
else
shared_coords = inv->Forward(
{local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
FloorMod(T.thread_var, 2),
warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
shared_coords.pop_back(); // remove rep
if (shared_layout.defined())
shared_coords = shared_layout->Forward(shared_coords);
PrimExpr shared_addr = shared_tensor.access_ptr(
is_ldmatrix ? 1 : 2, DataType::Handle(), 1,
shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
args.push_back(shared_addr);
if (is_ldmatrix) {
// Can only support same dtype for ldmatrx
if (local_tensor->dtype != shared_tensor->dtype)
return Stmt();
PrimExpr local_addr = local_tensor.access_ptr(
2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num));
args.push_back(local_addr);
} else {
for (int i = 0; i < num; i++) {
PrimExpr value0 =
BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i});
PrimExpr value1 =
BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1});
if (local_tensor->dtype != shared_tensor->dtype) {
value0 = Cast(shared_tensor->dtype, value0);
value1 = Cast(shared_tensor->dtype, value1);
}
PrimExpr value_packed =
Call(DataType::Int(32), pack_b16(), {value0, value1});
args.push_back(value_packed);
}
}
auto body = Evaluate(Call(DataType::Handle(), op, args));
For for_node =
For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
for_node = LoopPragmaUnroll(for_node);
auto range = T.thread_bounds;
if (range.defined()) {
auto thread_var = T.thread_var;
auto thread_var_with_offset = thread_var - range->min;
for_node.CopyOnWrite()->body =
Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
}
return for_node;
}
LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
// Use parallel op to infer the layout
if (par_op_ == nullptr) {
arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
}
return par_op_->InferLayout(T, level);
}
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
  if (args[0]->IsInstance<BufferLoadNode>()) {
@@ -479,11 +122,6 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  }
}
TIR_REGISTER_TL_OP(Copy, copy)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_REGISTER_TL_OP(Fill, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
...
@@ -15,53 +15,6 @@ namespace tl {
using namespace tir;
class Copy : public Operator {
public:
Copy(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
Copy(const Copy &other)
: args_(other.args_), src(other.src), dst(other.dst),
src_range(other.src_range), dst_range(other.dst_range),
coalesced_width(other.coalesced_width), disable_tma(other.disable_tma) {
// No clone nullptr
if (other.par_op_)
par_op_ = std::unique_ptr<ParallelOp>(
static_cast<ParallelOp *>(other.par_op_->Clone().release()));
}
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Copy>(*this);
}
protected:
Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
Array<IterVar> MakeIterVars() const;
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
Array<PrimExpr> args_;
Buffer src, dst;
Array<Range> src_range, dst_range;
IntImm coalesced_width;
Bool disable_tma = Bool(false);
std::unique_ptr<ParallelOp> par_op_;
int eviction_policy;
};
class Fill : public Operator {
public:
  Fill(Array<PrimExpr> args, BufferMap vmap);
...
@@ -49,7 +49,6 @@ struct LowerArgs {
  AddWorkspaceCallback AddWorkspace;
  LayoutMap layout_map;
  Map<Buffer, Buffer> buffer_remap;
- bool disable_tma_lower;
};

struct LayoutInferArgs {
...
@@ -14,7 +14,6 @@
#include <vector>

#include "../op/builtin.h"
-#include "../op/bulk_copy.h"
#include "arith/pattern_match.h"
#include "target/source/ptx.h"
@@ -1100,7 +1099,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
      ss << "tl::tma_store";
    }
    print_extern_call_stmt(ss.str(), 0, 1);
- } else if (op->op.same_as(tl::ptx_ldmatirx())) {
+ } else if (op->op.same_as(tl::ptx_ldmatrix())) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
...
@@ -14,7 +14,6 @@
#include <vector>

#include "../op/builtin.h"
-#include "../op/bulk_copy.h"
#include "target/source/ptx.h"

namespace tvm {
...