Unverified Commit 8fbe1b3a authored by Lei Wang, committed by GitHub

[Refactor] Add kernel selection option for GEMM v1 in environment settings (#1200)

* Add kernel selection option for GEMM v1 in environment settings

- Introduced `TILELANG_USE_GEMM_V1` environment variable to control the selection of GEMM version.
- Added `use_gemm_v1` method in the `Environment` class to determine if GEMM v1 should be used based on the environment variable.
- Updated GEMM function assignment to default to v2, allowing for v1 to be forced via the new environment variable.

* bug fix

* Add kernel selection option for GEMM in environment settings

- Introduced `TILELANG_USE_GEMM_V1` environment variable to allow users to select between GEMM v1 and v2 implementations.
- Updated `gemm` function to default to v2 but switch to v1 if the environment variable is set to a truthy value.
- Added a method `use_gemm_v1` in the `Environment` class to facilitate this selection based on the environment variable.
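
A minimal standalone sketch of this selection, assuming a pair of implementations with placeholder names `_gemm_v1` and `_gemm_v2` (the real dispatch lives in tilelang and uses the `Environment` helper shown further down in this diff):

import os

def _gemm_v1(*args, **kwargs):
    ...  # placeholder for the legacy GEMM implementation

def _gemm_v2(*args, **kwargs):
    ...  # placeholder for the default GEMM implementation

def use_gemm_v1() -> bool:
    # Mirrors Environment.use_gemm_v1: only these truthy strings force v1.
    return os.environ.get("TILELANG_USE_GEMM_V1", "0").lower() in ("1", "true", "yes", "on")

# Default to v2; setting TILELANG_USE_GEMM_V1 opts back into v1.
gemm = _gemm_v1 if use_gemm_v1() else _gemm_v2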

* Refactor GEMM macro generator to use BufferRegion instead of Buffer

- Updated `wgmma` and `wgmma_rs` methods in `TensorCoreIntrinEmitter` to accept `BufferRegion` parameters instead of `Buffer`.
- Adjusted related calls in `GemmWGMMA` to ensure compatibility with the new parameter types.
- Simplified buffer access logic for better clarity and maintainability.

* Refactor GEMM functions to utilize BufferRegion for improved memory handling

- Updated `run_gemm`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` functions to set `num_stages` based on block dimensions, enhancing performance for larger matrices.
- Simplified calls to GEMM functions by removing redundant parameters and ensuring compatibility with BufferRegion.
- Introduced utility functions for converting between Buffer, BufferLoad, and BufferRegion, improving code clarity and maintainability.
- Enhanced error handling for full region checks in GEMM operations to ensure correctness in memory access.
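
A hedged sketch of what such a conversion utility could look like; the exact `to_buffer_region` implementation is not part of the hunks below, so the behavior here is an assumption inferred from how the macro generators call it (reading `region.buffer` plus the per-axis `region[i].min` offsets):

from tvm.ir import Range
from tvm.tir import Buffer, BufferLoad, BufferRegion

def to_buffer_region(obj):
    # BufferRegion passes through unchanged; a plain Buffer becomes a full-extent
    # region; a BufferLoad becomes a unit-extent region rooted at its indices.
    if isinstance(obj, BufferRegion):
        return obj
    if isinstance(obj, Buffer):
        return BufferRegion(obj, [Range.from_min_extent(0, extent) for extent in obj.shape])
    if isinstance(obj, BufferLoad):
        return BufferRegion(obj.buffer, [Range.from_min_extent(index, 1) for index in obj.indices])
    raise TypeError(f"Expected Buffer, BufferLoad or BufferRegion, got {type(obj)}")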

* Refactor GEMM code for improved readability and consistency

- Cleaned up formatting and spacing in GEMM-related files for better readability.
- Standardized comments and code structure across various GEMM functions and macros.
- Enhanced error messages for clarity in buffer region checks.
- Removed redundant lines and improved overall code maintainability.

* Update GEMM correctness evaluation and macro generator for improved functionality

- Modified `N_VALUES` in `correctness_evaluation_sm70.py` to include only relevant sizes for tests.
- Updated test function call in `correctness_evaluation.py` to use `test_gemm_false_true` for better accuracy in testing.
- Refactored buffer handling in `mma_sm70_macro_generator.py` to improve clarity and consistency in shared buffer access.
- Enhanced `gemm_mma_sm70.py` to ensure full region checks for input and output buffers, improving correctness in GEMM operations.

* Refactor GEMM and intrinsic files for improved clarity and functionality

- Removed unused variable `A_stride_last` in `mma_sm70_macro_generator.py` to streamline code.
- Adjusted function signature formatting in `swizzle.py` for better readability.
- Restored the return of `GemmWGMMA` in `__init__.py` for correct GEMM instantiation.
- Removed unused variable `B_buf` in `gemm_mma_sm70.py` to enhance code cleanliness.
- Improved function signature formatting in `language.py` for consistency.

* Enhance GEMM and MMA functionality for FP64 support

- Refactored `GemmNode` to streamline the decision-making process for GEMM instruction selection.
- Added support for FP64 inputs in the MMA dispatcher, enabling new tensor operations.
- Introduced a new layout function for FP64 in `mma_layout.py` to facilitate shared memory storage.
- Updated `TensorCoreIntrinEmitter` to handle FP64 data types, including adjustments for micro tile dimensions and loading mechanisms.
- Enhanced utility functions to accommodate FP64 index mapping for shared memory operations.
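
For reference, the new FP64 store layout (`mma_store_32x2_to_shared_8x8_layout_fp64` in the hunks below) maps 32 threads with 2 accumulator registers each onto an 8x8 tile; a quick standalone sanity check that the mapping is a bijection:

def mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id):
    row = thread_id // 4
    col = (thread_id % 4) * 2 + local_id
    return row, col

# 32 threads x 2 locals must cover the 64 elements of the 8x8 tile exactly once.
covered = {mma_store_32x2_to_shared_8x8_layout_fp64(t, l) for t in range(32) for l in range(2)}
assert covered == {(r, c) for r in range(8) for c in range(8)}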

* lint fix

* Refactor GEMM correctness evaluation and shared memory alignment handling

- Reverted the GEMM function call in `correctness_evaluation.py` to the original implementation for consistency.
- Added a helper function in `merge_shared_memory_allocations.cc` to streamline the marking of shared variables under alignment scope.
- Enhanced the `VisitExpr_` methods to ensure proper handling of shared memory alignment for `BufferLoadNode` and `VarNode` types.
- Cleaned up commented-out test code in `correctness_evaluation.py` for better readability.

* Enhance GEMM and MMA implementations with region-based memory handling

- Updated GEMM and MMA classes to utilize BufferRegion for input and output buffers, improving memory management and supporting strided GEMM operations.
- Added checks to ensure full region compliance for input buffers, enhancing correctness in matrix multiplication.
- Implemented clear accumulation functionality to reset output buffers before accumulation, ensuring accurate results in GEMM operations.
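
A conservative sketch of what a full-region check could look like; the actual `is_full_region` used in the emitter below is imported from `tilelang.utils` and is not shown in this diff, so this is a hypothetical stand-in:

from tvm import ir, tir

def is_full_region(region: tir.BufferRegion) -> bool:
    # A region is "full" when every axis starts at 0 and its extent structurally
    # matches the corresponding buffer dimension.
    return all(
        isinstance(r.min, tir.IntImm) and r.min.value == 0 and ir.structural_equal(r.extent, dim)
        for r, dim in zip(region.region, region.buffer.shape))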

* Refactor test_tilelang_example_deepseek_v32.py to improve import structure and function calls

- Updated import statements to directly reference modules instead of individual test functions, enhancing clarity.
- Modified function calls to use the new module structure for better organization and maintainability in testing examples.

* Enhance OnArrayDeclaration method to handle repeated buffer declarations

- Updated the OnArrayDeclaration method to merge metadata for buffers that may appear in multiple Allocate statements, improving robustness against upstream transformations.
- Added logic to prefer concrete element data types and record extents when previously unknown, enhancing the handling of buffer declarations.

* Add abbreviation for bfloat16 data type in mfma_macro_generator.py

- Introduced a new abbreviation "bf16" for the bfloat16 data type in the mfma_macro_generator.py file, enhancing clarity and consistency in data type representation.

* Refactor CodeGenTileLangHIP to enhance dtype handling and mfma call generation

- Introduced a mapping function to normalize input data types to their corresponding scalar types, improving compatibility with MfmaTraits.
- Updated the mfma call generation to utilize the new mapping, streamlining the code and enhancing clarity.
- Removed outdated dtype mapping and replaced it with a more flexible approach to support additional data types like FP8.

* lint fix

* Enhance backend configuration in CMakeLists.txt and improve dtype handling in CodeGenTileLangHIP

- Introduced a macro to define backend options for CUDA, ROCM, and Metal, allowing user overrides and caching of settings.
- Updated logic to track user-selected backends and conditionally enable defaults based on environment variables.
- Refactored dtype handling in CodeGenTileLangHIP to streamline mfma call generation and improve clarity.
- Added support for bfloat16 in the mfma_macro_generator.py, enhancing data type representation consistency.

* Update bfloat16 handling in CodeGenTileLangHIP and mfma_macro_generator.py

- Changed the representation of bfloat16 in CodeGenTileLangHIP from "bfloat16x4" to "bfloat16x4_vec" for improved clarity.
- Adjusted the mfma_suffix generation in mfma_macro_generator.py to remove the underscore before "bf16", aligning with HIP intrinsic requirements.

* Change logging level from WARNING to DLOG in LegalizeNegativeIndex for non-negative index checks to reduce log verbosity.

* Refactor attention sink examples to simplify index calculations

- Updated index handling in `example_gqa_sink_bwd_bhsd.py` and `example_mha_sink_bwd_bhsd.py` to eliminate unnecessary local allocations and streamline logic for determining start and end indices.
- Improved readability by using direct calculations instead of local variables for index bounds in pipelined loops.

* Refactor attention sink examples to streamline index calculations

- Simplified index handling in `example_gqa_sink_bwd_bhsd.py`, `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py`, `example_mha_sink_bwd_bhsd.py`, `example_mha_sink_fwd_bhsd_wgmma_pipelined.py`, and `example_mha_sink_fwd_bhsd.py` by removing unnecessary local allocations for start and end indices.
- Enhanced readability by directly calculating index bounds for pipelined loops, improving overall code clarity.

* lint fix

* bugfix

* Refactor reduce operation handling in CUDA and Python

- Removed outdated shared memory reduction logic from `reduce.cc`.
- Introduced fragment allocation and improved buffer handling in `reduce.py` to support shared and fragment scopes.
- Updated CUDA header to define a wider accumulator type for better numerical accuracy.
- Enhanced error handling for buffer scope validation in the reduction process.

* Fix ReduceOpNode to correctly compute AbsMax by using absolute values of inputs

* Enhance unit loop handling by refining annotation checks

- Updated the condition for identifying effectively empty annotations in unit loops to include cases where only the `pragma_unroll_explicit` hint is present.
- Introduced a new method, `IsEffectivelyEmptyAnnotation`, to encapsulate this logic, improving code clarity and maintainability.

* clean code
parent 2b1f5990
...@@ -928,7 +928,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -928,7 +928,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{"float32", "float"}, {"float32", "float"},
{"float64", "double"}, {"float64", "double"},
{"float16x4", "float16x4"}, {"float16x4", "float16x4"},
{"bfloat16x4", "bfloat16x4"}, {"bfloat16x4", "bfloat16x4_vec"},
{"float32x4", "float32x4"}, {"float32x4", "float32x4"},
{"float8_e4m3fnuzx4", "fp8_e4_4_t"}, {"float8_e4m3fnuzx4", "fp8_e4_4_t"},
{"float8_e4m3fnuzx8", "long"}, {"float8_e4m3fnuzx8", "long"},
......
...@@ -136,6 +136,10 @@ TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 8, ...@@ -136,6 +136,10 @@ TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 8,
false, true, false, false, true, false,
cute::SM80_16x8x8_F32TF32TF32F32_TN) cute::SM80_16x8x8_F32TF32TF32F32_TN)
// FP64 inputs (DMMA: m8n8k4, TN layout)
TL_DEFINE_MMA_DISPATCHER(kFloat64, kFloat64, kFloat64, 8, 8, 4, false, true,
false, cute::SM80_8x8x4_F64F64F64F64_TN)
#undef TL_DEFINE_MMA_DISPATCHER #undef TL_DEFINE_MMA_DISPATCHER
} // namespace detail } // namespace detail
......
#pragma once #pragma once
#include "common.h" #include "common.h"
#include <cstdint>
#include <type_traits>
namespace tl { namespace tl {
// Select a wider accumulator type for improved numerical accuracy.
// Default: accumulate in the same type. Specialize FP16/BF16 to float.
template <typename T> struct AccType {
using type = T;
};
template <> struct AccType<half_t> {
using type = float;
};
template <> struct AccType<bfloat16_t> {
using type = float;
};
struct SumOp { struct SumOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) { template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x + y; return x + y;
...@@ -40,53 +54,6 @@ struct BitXorOp { ...@@ -40,53 +54,6 @@ struct BitXorOp {
} }
}; };
template <class Reducer, int Threads, bool UseAbs, bool NeedAccumulate>
struct SharedReduceWarp {
template <typename T>
static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
int total_dest, int reduce_extent, int tail,
T init_value) {
if (total_dest <= 0 || reduce_extent <= 0)
return;
constexpr int kWarpSize = 32;
static_assert(Threads % kWarpSize == 0,
"SharedReduceWarp expects blockDim.x to be a multiple of "
"warp size on CUDA.");
const int tid = threadIdx.x;
const int warp_id = tid / kWarpSize;
const int lane = tid % kWarpSize;
const int num_warps = Threads / kWarpSize;
for (int dest_idx = warp_id; dest_idx < total_dest; dest_idx += num_warps) {
const int prefix = tail == 1 ? dest_idx : dest_idx / tail;
const int suffix = tail == 1 ? 0 : dest_idx % tail;
const int src_base = (prefix * reduce_extent) * tail + suffix;
const int dst_index = prefix * tail + suffix;
T partial = init_value;
for (int rv = lane; rv < reduce_extent; rv += kWarpSize) {
T val = src[src_base + rv * tail];
if constexpr (UseAbs) {
val = val < T(0) ? -val : val;
}
partial = Reducer()(partial, val);
}
unsigned mask = __activemask();
for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
T other = tl::shfl_down_sync(mask, partial, offset);
partial = Reducer()(partial, other);
}
if (lane == 0) {
if constexpr (NeedAccumulate) {
partial = Reducer()(dst[dst_index], partial);
}
dst[dst_index] = partial;
}
}
}
};
template <class Reducer, int threads, int scale, int thread_offset = 0, template <class Reducer, int threads, int scale, int thread_offset = 0,
int all_threads = threads> int all_threads = threads>
struct AllReduce { struct AllReduce {
......
...@@ -123,7 +123,7 @@ public: ...@@ -123,7 +123,7 @@ public:
states.push_back(IndexSignState::kUnknown); states.push_back(IndexSignState::kUnknown);
needs_record = true; needs_record = true;
LOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index " DLOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name << simplified << " for buffer " << load->buffer->name
<< " (axis " << i << ")."; << " (axis " << i << ").";
} }
......
...@@ -119,7 +119,7 @@ private: ...@@ -119,7 +119,7 @@ private:
// Step 1. Update unit loop info. // Step 1. Update unit loop info.
PrimExpr min = this->VisitExpr(op->min); PrimExpr min = this->VisitExpr(op->min);
PrimExpr extent = this->VisitExpr(op->extent); PrimExpr extent = this->VisitExpr(op->extent);
if (is_one(extent) && op->annotations.empty()) { if (is_one(extent) && IsEffectivelyEmptyAnnotation(op->annotations)) {
// handling unit loop // handling unit loop
unit_loop_vars_[op->loop_var] = min; unit_loop_vars_[op->loop_var] = min;
} }
...@@ -135,7 +135,8 @@ private: ...@@ -135,7 +135,8 @@ private:
ICHECK(op->thread_binding.defined()); ICHECK(op->thread_binding.defined());
String thread_tag = op->thread_binding.value()->thread_tag; String thread_tag = op->thread_binding.value()->thread_tag;
body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
} else if (is_one(extent) && op->annotations.empty()) { } else if (is_one(extent) &&
IsEffectivelyEmptyAnnotation(op->annotations)) {
// Case 2. Unit loop // Case 2. Unit loop
return body; return body;
} else { } else {
...@@ -150,6 +151,23 @@ private: ...@@ -150,6 +151,23 @@ private:
return body; return body;
} }
// Treat annotations as empty if they are truly empty or contain only
// the unroll hint `pragma_unroll_explicit`. This allows unit-length
// loops produced by unroll pragmas to be simplified away.
bool
IsEffectivelyEmptyAnnotation(const Map<String, ffi::Any> &annotations) const {
if (annotations.empty()) {
return true;
}
if (annotations.size() == 1) {
auto it = annotations.find(tir::attr::pragma_unroll_explicit);
if (it != annotations.end()) {
return true;
}
}
return false;
}
PrimExpr VisitExpr_(const VarNode *op) final { PrimExpr VisitExpr_(const VarNode *op) final {
Var var = tvm::ffi::GetRef<Var>(op); Var var = tvm::ffi::GetRef<Var>(op);
auto it = unit_loop_vars_.find(var); auto it = unit_loop_vars_.find(var);
......
...@@ -104,55 +104,6 @@ private: ...@@ -104,55 +104,6 @@ private:
Map<Buffer, Layout> layout_remap_; Map<Buffer, Layout> layout_remap_;
}; };
class BufferGemmCollector : public StmtExprVisitor {
public:
BufferGemmCollector() { Clear(); }
void Clear() { buffer_var_gemm_.clear(); }
void Collect(const Stmt &stmt) { VisitStmt(stmt); }
Array<Var> GetBufferVarGemm() { return buffer_var_gemm_; }
private:
void VisitStmt_(const EvaluateNode *op) {
const CallNode *call_node = op->value.as<CallNode>();
// Value of EvaluateNode may not be a call
if (!call_node) {
return;
}
auto call = Downcast<Call>(call_node);
if (call->op.same_as(Gemm::Get())) {
auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]);
ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto srcA_buffer_var = Downcast<Var>(srcA_buffer_access_ptr->args[1]);
auto srcB_buffer_access_ptr = Downcast<Call>(call->args[1]);
ICHECK(srcB_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto srcB_buffer_var = Downcast<Var>(srcB_buffer_access_ptr->args[1]);
auto dst_buffer_access_ptr = Downcast<Call>(call->args[2]);
ICHECK(dst_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto dst_buffer_var = Downcast<Var>(dst_buffer_access_ptr->args[1]);
buffer_var_gemm_.push_back(srcA_buffer_var);
buffer_var_gemm_.push_back(srcB_buffer_var);
buffer_var_gemm_.push_back(dst_buffer_var);
} else if (call->op.same_as(GemmSP::Get())) {
auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]);
ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto srcA_buffer_var = Downcast<Var>(srcA_buffer_access_ptr->args[1]);
auto srcB_buffer_access_ptr = Downcast<Call>(call->args[1]);
ICHECK(srcB_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto srcB_buffer_var = Downcast<Var>(srcB_buffer_access_ptr->args[1]);
auto dst_buffer_access_ptr = Downcast<Call>(call->args[2]);
ICHECK(dst_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto dst_buffer_var = Downcast<Var>(dst_buffer_access_ptr->args[1]);
buffer_var_gemm_.push_back(srcA_buffer_var);
buffer_var_gemm_.push_back(srcB_buffer_var);
buffer_var_gemm_.push_back(dst_buffer_var);
}
}
Array<Var> buffer_var_gemm_;
};
/*! /*!
* \brief A class that rewrites buffer references in a statement based on a * \brief A class that rewrites buffer references in a statement based on a
...@@ -254,11 +205,6 @@ public: ...@@ -254,11 +205,6 @@ public:
auto target = f->GetAttr<Target>(tvm::attr::kTarget); auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute"; ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute";
substituter.target_ = target.value(); substituter.target_ = target.value();
// For TMA 1D, we should collect the buffers which are not used in GEMM and
// do not need swizzle
BufferGemmCollector collector;
collector.Collect(f->body);
substituter.buffer_var_gemm_ = collector.GetBufferVarGemm();
PrimFuncNode *fptr = f.CopyOnWrite(); PrimFuncNode *fptr = f.CopyOnWrite();
fptr->body = substituter.VisitStmt(f->body); fptr->body = substituter.VisitStmt(f->body);
fptr->body = fptr->body =
...@@ -693,9 +639,9 @@ private: ...@@ -693,9 +639,9 @@ private:
thread_bounds = Range::FromMinExtent(0, 1); thread_bounds = Range::FromMinExtent(0, 1);
} }
auto lowered = tile_op->Lower( auto lowered =
LowerArgs{target_, thread_bounds, thread_var_->var, callback, tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var,
layout_map_, buffer_remap_, buffer_var_gemm_}, callback, layout_map_, buffer_remap_},
analyzer_); analyzer_);
return IRMutatorWithAnalyzer::VisitStmt(lowered); return IRMutatorWithAnalyzer::VisitStmt(lowered);
} }
...@@ -734,7 +680,6 @@ private: ...@@ -734,7 +680,6 @@ private:
std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_; std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
Map<Var, Var> var_remap_; Map<Var, Var> var_remap_;
bool has_tma_{false}; bool has_tma_{false};
Array<Var> buffer_var_gemm_;
}; };
namespace transform { namespace transform {
......
...@@ -354,11 +354,28 @@ public: ...@@ -354,11 +354,28 @@ public:
} }
private: private:
// Helper to record alignment for a shared/shared.dyn Var under alignment
// scope
void MarkSharedVarIfNeeded(const VarNode *op) {
if (!op || !under_alignment_scope_)
return;
auto ptr_type = op->type_annotation.as<PointerTypeNode>();
if (!ptr_type)
return;
auto scope = GetPtrStorageScope(tvm::ffi::GetRef<Var>(op));
if (scope == "shared" || scope == "shared.dyn") {
auto target = Target::Current();
ICHECK(target.defined()) << "Target is not defined";
const int alignment = TargetIsHopper(target) ? 1024 : 16;
shmem_alignment_map_[op] = alignment;
}
}
void VisitExpr_(const CallNode *op) { void VisitExpr_(const CallNode *op) {
if (op->op.same_as(tl::tl_gemm()) || op->op.same_as(tl::tl_gemm_sp()) || if (op->op.same_as(tl::tl_gemm()) || op->op.same_as(tl::tl_gemm_sp()) ||
op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_store()) || op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_store()) ||
op->op.same_as(tl::ptx_wgmma_ss()) || op->op.same_as(tl::initialize_wgmma_descriptor()) ||
op->op.same_as(tl::ptx_wgmma_rs())) { op->op.same_as(tl::initialize_tcgen05_descriptor())) {
// These intrinsics introduce stricter SMEM alignment requirements; mark // These intrinsics introduce stricter SMEM alignment requirements; mark
// the subtree. // the subtree.
under_alignment_scope_ = true; under_alignment_scope_ = true;
...@@ -370,15 +387,16 @@ private: ...@@ -370,15 +387,16 @@ private:
} }
void VisitExpr_(const VarNode *op) { void VisitExpr_(const VarNode *op) {
auto ptr_type = op->type_annotation.as<PointerTypeNode>(); MarkSharedVarIfNeeded(op);
if (ptr_type && under_alignment_scope_) { StmtExprVisitor::VisitExpr_(op);
auto scope = GetPtrStorageScope(tvm::ffi::GetRef<Var>(op));
if (scope == "shared" || scope == "shared.dyn") {
auto target = Target::Current();
ICHECK(target.defined()) << "Target is not defined";
const int alignment = TargetIsHopper(target) ? 1024 : 16;
shmem_alignment_map_[op] = alignment;
} }
void VisitExpr_(const BufferLoadNode *op) {
// If we encounter address_of(BufferLoad(...)) or any direct BufferLoad
// within an alignment scope, make sure we mark the underlying shared var.
if (op && under_alignment_scope_) {
const VarNode *data_var = op->buffer->data.get();
MarkSharedVarIfNeeded(data_var);
} }
StmtExprVisitor::VisitExpr_(op); StmtExprVisitor::VisitExpr_(op);
} }
......
...@@ -1425,9 +1425,30 @@ public: ...@@ -1425,9 +1425,30 @@ public:
void void
OnArrayDeclaration(const Var &buffer, DataType element_dtype, PrimExpr extent, OnArrayDeclaration(const Var &buffer, DataType element_dtype, PrimExpr extent,
BufferVarInfo::DeclarationLocation declaration_location) { BufferVarInfo::DeclarationLocation declaration_location) {
ICHECK(info_map_.find(buffer.get()) == info_map_.end()) auto it = info_map_.find(buffer.get());
<< "Array declaration of " << buffer->name_hint if (it != info_map_.end()) {
<< " occurred multiple times."; // The same buffer var may appear in more than one Allocate due to
// upstream transforms (e.g., storage planning/merging). Treat repeated
// declarations as benign and merge metadata instead of erroring.
BufferVarInfo &existing = it->second;
// Prefer a concrete element dtype if the previous one was a handle.
if (existing.element_dtype.is_handle() && !element_dtype.is_handle()) {
existing.element_dtype =
element_dtype == DataType::Bool()
? DataType::Int(8).with_lanes(element_dtype.lanes())
: element_dtype;
}
// If extent was previously unknown (0) and a concrete extent is
// provided now, record it.
if (!existing.extent.defined() || is_zero(existing.extent)) {
existing.extent = extent;
}
// Merge declaration locations (bitwise OR of flags).
existing.declaration_location =
static_cast<BufferVarInfo::DeclarationLocation>(
existing.declaration_location | declaration_location);
return;
}
if (element_dtype == DataType::Bool()) { if (element_dtype == DataType::Bool()) {
element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes()); element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
......
...@@ -514,4 +514,5 @@ def test_assert_tl_matmul_block_all_dynamic_with_pass_config(): ...@@ -514,4 +514,5 @@ def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() # tilelang.testing.main()
assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
...@@ -236,6 +236,10 @@ class Environment: ...@@ -236,6 +236,10 @@ class Environment:
"1") # print kernel name on compile "1") # print kernel name on compile
TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # clear cache automatically if set TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # clear cache automatically if set
# Kernel selection options
# Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "0")
# Auto-tuning settings # Auto-tuning settings
TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES",
"0.9") # percent of CPUs used "0.9") # percent of CPUs used
...@@ -274,6 +278,14 @@ class Environment: ...@@ -274,6 +278,14 @@ class Environment:
def is_print_on_compilation_enabled(self) -> bool: def is_print_on_compilation_enabled(self) -> bool:
return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on") return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on")
def use_gemm_v1(self) -> bool:
"""Return True if GEMM v1 should be used based on env.
Controlled by `TILELANG_USE_GEMM_V1`. Truthy values are one of
{"1", "true", "yes", "on"} (case-insensitive).
"""
return str(self.TILELANG_USE_GEMM_V1).lower() in ("1", "true", "yes", "on")
# Instantiate as a global configuration object # Instantiate as a global configuration object
env = Environment() env = Environment()
......
...@@ -2,14 +2,14 @@ from __future__ import annotations ...@@ -2,14 +2,14 @@ from __future__ import annotations
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang.language as T import tilelang.language as T
from tvm import DataType from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tvm.runtime import convert from tvm.runtime import convert
from .utils import ( from .utils import (
mfma_store_index_map,) mfma_store_index_map,)
from typing import Literal, Callable from typing import Literal, Callable
from tilelang.utils import is_fragment from tilelang.utils import is_fragment
from tilelang.utils.language import to_buffer_region
from .mfma_layout import ( from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A, shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B, shared_4x16_to_local_64x1_layout_B,
...@@ -139,6 +139,7 @@ class MatrixCoreIntrinEmitter: ...@@ -139,6 +139,7 @@ class MatrixCoreIntrinEmitter:
}[out_dtype] }[out_dtype]
in_dtype_abbrv = { in_dtype_abbrv = {
"bfloat16": "bf16",
"float16": "f16", "float16": "f16",
"float32": "f32", "float32": "f32",
"int8": "i8", "int8": "i8",
...@@ -150,6 +151,9 @@ class MatrixCoreIntrinEmitter: ...@@ -150,6 +151,9 @@ class MatrixCoreIntrinEmitter:
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8" self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
elif in_dtype_abbrv == "i8": elif in_dtype_abbrv == "i8":
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8" self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
elif in_dtype_abbrv == "bf16":
# HIP intrinsic uses ...x{K}bf16_1k without an underscore before bf16
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}bf16_1k"
else: else:
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}" self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
...@@ -251,7 +255,7 @@ class MatrixCoreIntrinEmitter: ...@@ -251,7 +255,7 @@ class MatrixCoreIntrinEmitter:
(WARP_SIZE * block_row_warps)) % block_col_warps, (WARP_SIZE * block_row_warps)) % block_col_warps,
return lane_id, warp_n, warp_m return lane_id, warp_n, warp_m
def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0): def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
warp_row_tiles = self.warp_row_tiles warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows warp_rows = self.warp_rows
chunk = self.chunk chunk = self.chunk
...@@ -263,6 +267,12 @@ class MatrixCoreIntrinEmitter: ...@@ -263,6 +267,12 @@ class MatrixCoreIntrinEmitter:
thread_binding = self.get_thread_binding() thread_binding = self.get_thread_binding()
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=False) _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
@T.macro @T.macro
def _warp_ldmatrix_a( def _warp_ldmatrix_a(
A_local_buf, A_local_buf,
...@@ -278,20 +288,20 @@ class MatrixCoreIntrinEmitter: ...@@ -278,20 +288,20 @@ class MatrixCoreIntrinEmitter:
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (rk * chunk + ki * (k_pack * micro_size_k), l, r = (rk * chunk + ki * (k_pack * micro_size_k),
warp_m * warp_row_tiles + i * micro_size_x) warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
r + col] A_base1 + r + col]
else: else:
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a): for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x, l, r = (warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * (k_pack * micro_size_k)) rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
r + col] A_base1 + r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0): def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0):
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols warp_cols = self.warp_cols
chunk = self.chunk chunk = self.chunk
...@@ -303,6 +313,12 @@ class MatrixCoreIntrinEmitter: ...@@ -303,6 +313,12 @@ class MatrixCoreIntrinEmitter:
thread_binding = self.get_thread_binding() thread_binding = self.get_thread_binding()
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=True) _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
@T.macro @T.macro
def _warp_ldmatrix_b( def _warp_ldmatrix_b(
B_local_buf, B_local_buf,
...@@ -320,8 +336,8 @@ class MatrixCoreIntrinEmitter: ...@@ -320,8 +336,8 @@ class MatrixCoreIntrinEmitter:
warp_n * warp_col_tiles + j * micro_size_y, warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * (k_pack * micro_size_k), rk * chunk + ki * (k_pack * micro_size_k),
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row,
r + col] B_base1 + r + col]
else: else:
for j in T.serial(warp_cols): for j in T.serial(warp_cols):
...@@ -331,8 +347,8 @@ class MatrixCoreIntrinEmitter: ...@@ -331,8 +347,8 @@ class MatrixCoreIntrinEmitter:
rk * chunk + ki * (k_pack * micro_size_k), rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y, warp_n * warp_col_tiles + j * micro_size_y,
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row,
r + col] B_base1 + r + col]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
......
...@@ -45,6 +45,12 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): ...@@ -45,6 +45,12 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
return row, col return row, col
def mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id):
row = thread_id // 4
col = (thread_id % 4) * 2 + local_id
return row, col
# sr represents spatial + reduction layout # sr represents spatial + reduction layout
# the first axis is spatial while the second axis is reduction # the first axis is spatial while the second axis is reduction
# mma.sync matrix A layout, if wanna trans, please apply map_indices # mma.sync matrix A layout, if wanna trans, please apply map_indices
......
...@@ -3,13 +3,14 @@ import tilelang.language as T ...@@ -3,13 +3,14 @@ import tilelang.language as T
from typing import Literal, Callable from typing import Literal, Callable
from tilelang.common import TransformKind from tilelang.common import TransformKind
from tvm import DataType from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tilelang import tvm as tvm
from tvm.runtime import convert from tvm.runtime import convert
from .utils import ( from .utils import (
mma_store_index_map, mma_store_index_map,
get_ldmatrix_offset, get_ldmatrix_offset,
) )
from tilelang.utils import is_fragment from tilelang.utils import is_fragment, to_buffer_region
from tilelang.intrinsics.mma_layout import ( from tilelang.intrinsics.mma_layout import (
shared_16x8_to_mma_32x4_layout_sr_a, shared_16x8_to_mma_32x4_layout_sr_a,
shared_16x8_to_mma_32x4_layout_sr_b, shared_16x8_to_mma_32x4_layout_sr_b,
...@@ -40,6 +41,7 @@ class TensorCoreIntrinEmitter: ...@@ -40,6 +41,7 @@ class TensorCoreIntrinEmitter:
"float16": "fp16", "float16": "fp16",
"bfloat16": "bf16", "bfloat16": "bf16",
"float32": "fp32", "float32": "fp32",
"float64": "fp64",
"int8": "int8", "int8": "int8",
"int32": "int32", "int32": "int32",
"float8_e4m3": "e4m3", "float8_e4m3": "e4m3",
...@@ -78,6 +80,11 @@ class TensorCoreIntrinEmitter: ...@@ -78,6 +80,11 @@ class TensorCoreIntrinEmitter:
self.warp_col_tiles = warp_col_tiles self.warp_col_tiles = warp_col_tiles
self.chunk = chunk self.chunk = chunk
self._initialize_k_dim(a_dtype) self._initialize_k_dim(a_dtype)
# For FP64, MMA shape is m8n8k4; adjust instance dims early
if DataType(a_dtype).bits == 64:
# Override default M/N dims for fp64 MMA
self.M_DIM = 8
# n_dim will be set to 8 in _initialize_micro_size via k_dim==4
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_micro_size(self.M_DIM, self.k_dim) self._initialize_micro_size(self.M_DIM, self.k_dim)
self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE) self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE)
...@@ -116,7 +123,10 @@ class TensorCoreIntrinEmitter: ...@@ -116,7 +123,10 @@ class TensorCoreIntrinEmitter:
raise ValueError(f"Unsupported dtype: {dtype}") from err raise ValueError(f"Unsupported dtype: {dtype}") from err
def _initialize_mma_prefix(self, k_dim: int = 16): def _initialize_mma_prefix(self, k_dim: int = 16):
if k_dim == 8: if k_dim == 4:
# fp64
self.mma_prefix = "m8n8k4"
elif k_dim == 8:
# typically used for tfloat32 # typically used for tfloat32
self.mma_prefix = "m16n8k8" self.mma_prefix = "m16n8k8"
elif k_dim == 16: elif k_dim == 16:
...@@ -131,6 +141,15 @@ class TensorCoreIntrinEmitter: ...@@ -131,6 +141,15 @@ class TensorCoreIntrinEmitter:
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
# For fp64 (k_dim==4), micro tile is 8x8, otherwise keep 16x{8|16}
if k_dim == 4:
# fp64 path: m_dim must be 8, n_dim 8
assert m_dim == 8, f"For fp64 MMA, m_dim must be 8, got {m_dim}"
self.n_dim = 8
self.micro_size_y = 8
self.warp_rows = warp_row_tiles // m_dim
self.warp_cols = warp_col_tiles // 8
else:
assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}" assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}"
assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}" assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}" assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}"
...@@ -164,7 +183,11 @@ class TensorCoreIntrinEmitter: ...@@ -164,7 +183,11 @@ class TensorCoreIntrinEmitter:
return self.thread_var return self.thread_var
def get_store_index_map(self, inverse: bool = False) -> IndexMap: def get_store_index_map(self, inverse: bool = False) -> IndexMap:
from .utils import mma_store_index_map, mma_store_index_map_fp64
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
if DataType(self.accum_dtype).bits == 64:
index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32")
else:
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32") index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32")
if not inverse: if not inverse:
return index_map return index_map
...@@ -205,9 +228,47 @@ class TensorCoreIntrinEmitter: ...@@ -205,9 +228,47 @@ class TensorCoreIntrinEmitter:
def ldmatrix_a(self, def ldmatrix_a(self,
A_local_buf: Buffer, A_local_buf: Buffer,
A_shared_buf: Buffer, A_shared_buf: Buffer | BufferRegion,
ki: PrimExpr, ki: PrimExpr,
rk: PrimExpr | None = 0): rk: PrimExpr | None = 0):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if DataType(self.a_dtype).bits == 64:
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
micro_size_x = self.micro_size_x # 8
micro_size_k = self.micro_size_k # 4
local_size_a = self.local_size_a # 1
a_transposed = self.a_transposed
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
@T.macro
def _warp_ld_a_fp64(
A_local_buf,
A_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_rows):
wi = warp_m * warp_row_tiles + i * micro_size_x
wk = rk * chunk + ki * micro_size_k
mi = tx // micro_size_k
mk = tx % micro_size_k
if a_transposed:
A_local_buf[i * local_size_a] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi]
else:
A_local_buf[i * local_size_a] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]
return _warp_ld_a_fp64(A_local_buf, A_region, ki, thread_binding, rk)
warp_row_tiles = self.warp_row_tiles warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows warp_rows = self.warp_rows
chunk = self.chunk chunk = self.chunk
...@@ -232,6 +293,13 @@ class TensorCoreIntrinEmitter: ...@@ -232,6 +293,13 @@ class TensorCoreIntrinEmitter:
thread_binding = self.get_thread_binding() thread_binding = self.get_thread_binding()
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
A_stride_last = A_buf.shape[-1]
@T.macro @T.macro
def _warp_ldmatrix_a( def _warp_ldmatrix_a(
A_local_buf, A_local_buf,
...@@ -240,14 +308,16 @@ class TensorCoreIntrinEmitter: ...@@ -240,14 +308,16 @@ class TensorCoreIntrinEmitter:
thread_binding, thread_binding,
rk=0, rk=0,
): ):
stride = A_shared_buf.shape[-1] stride = A_stride_last
tx, _, warp_m = self.extract_thread_binding(thread_binding) tx, _, warp_m = self.extract_thread_binding(thread_binding)
trans = self.a_transposed trans = self.a_transposed
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
# Assign A_shared_buf_elem # Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk] A_shared_buf_elem = A_buf[A_base0 + wk,
A_base1 + wi] if a_transposed else A_buf[A_base0 + wi,
A_base1 + wk]
if ldmatrix_available: if ldmatrix_available:
T.ptx_ldmatrix( T.ptx_ldmatrix(
...@@ -263,15 +333,59 @@ class TensorCoreIntrinEmitter: ...@@ -263,15 +333,59 @@ class TensorCoreIntrinEmitter:
else: else:
for j in T.serial(local_size_a): for j in T.serial(local_size_a):
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi] if a_transposed:
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk,
A_base1 + wi + mi]
else:
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi,
A_base1 + wk + mk]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)
def ldmatrix_b(self, def ldmatrix_b(self,
B_local_buf: Buffer, B_local_buf: Buffer,
B_shared_buf: Buffer, B_shared_buf: Buffer | BufferRegion,
ki: PrimExpr, ki: PrimExpr,
rk: PrimExpr | None = 0): rk: PrimExpr | None = 0):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if DataType(self.b_dtype).bits == 64:
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y # 8
micro_size_k = self.micro_size_k # 4
local_size_b = self.local_size_b # 1
b_transposed = self.b_transposed
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
@T.macro
def _warp_ld_b_fp64(
B_local_buf,
B_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
for j in T.serial(warp_cols):
wi = warp_n * warp_col_tiles + j * micro_size_y
wk = rk * chunk + ki * micro_size_k
mi = tx // micro_size_k
mk = tx % micro_size_k
if b_transposed:
B_local_buf[j * local_size_b] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
else:
B_local_buf[j * local_size_b] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]
return _warp_ld_b_fp64(B_local_buf, B_region, ki, thread_binding, rk)
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols warp_cols = self.warp_cols
chunk = self.chunk chunk = self.chunk
...@@ -281,6 +395,13 @@ class TensorCoreIntrinEmitter: ...@@ -281,6 +395,13 @@ class TensorCoreIntrinEmitter:
b_dtype = self.b_dtype b_dtype = self.b_dtype
b_transposed = self.b_transposed b_transposed = self.b_transposed
thread_binding = self.get_thread_binding() thread_binding = self.get_thread_binding()
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
B_stride_last = B_buf.shape[-1]
replicate_b = (self.n_dim == 16) replicate_b = (self.n_dim == 16)
# ldmatrix cannot be used for int8 + trans case. # ldmatrix cannot be used for int8 + trans case.
ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
...@@ -304,7 +425,7 @@ class TensorCoreIntrinEmitter: ...@@ -304,7 +425,7 @@ class TensorCoreIntrinEmitter:
thread_binding, thread_binding,
rk=0, rk=0,
): ):
stride = B_shared_buf.shape[-1] stride = B_stride_last
tx, warp_n, _ = self.extract_thread_binding(thread_binding) tx, warp_n, _ = self.extract_thread_binding(thread_binding)
trans = not b_transposed trans = not b_transposed
...@@ -316,8 +437,9 @@ class TensorCoreIntrinEmitter: ...@@ -316,8 +437,9 @@ class TensorCoreIntrinEmitter:
) )
if ldmatrix_available: if ldmatrix_available:
B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, B_shared_buf_elem = B_buf[B_base0 + wi,
wi] B_base1 + wk] if b_transposed else B_buf[B_base0 + wk,
B_base1 + wi]
T.ptx_ldmatrix( T.ptx_ldmatrix(
b_dtype, b_dtype,
...@@ -335,7 +457,12 @@ class TensorCoreIntrinEmitter: ...@@ -335,7 +457,12 @@ class TensorCoreIntrinEmitter:
# must be transposed. # must be transposed.
for j in T.serial(local_size_b): for j in T.serial(local_size_b):
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_shared_buf[wk + mk, wi + mi] if b_transposed:
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi,
B_base1 + wk + mk]
else:
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk,
B_base1 + wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
...@@ -623,8 +750,10 @@ class TensorCoreIntrinEmitter: ...@@ -623,8 +750,10 @@ class TensorCoreIntrinEmitter:
from tilelang.utils import is_fragment from tilelang.utils import is_fragment
shape = local_buf.shape shape = local_buf.shape
assert is_fragment(
local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}"
inverse_mma_store_layout = self.get_store_index_map(inverse=True) inverse_mma_store_layout = self.get_store_index_map(inverse=True)
assert is_fragment(local_buf), "local_buf must be a fragment"
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
local_size_out = self.local_size_out local_size_out = self.local_size_out
block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
......
...@@ -2,9 +2,10 @@ from __future__ import annotations ...@@ -2,9 +2,10 @@ from __future__ import annotations
import tilelang.language as T import tilelang.language as T
from typing import Literal, Callable from typing import Literal, Callable
from tvm import DataType from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tilelang import tvm as tvm
from tvm.runtime import convert from tvm.runtime import convert
from tilelang.utils import is_fragment from tilelang.utils import is_fragment, to_buffer_region
from tilelang.intrinsics.mma_sm70_layout import ( from tilelang.intrinsics.mma_sm70_layout import (
shared_16x4_to_mma_a_32x4_layout, shared_16x4_to_mma_a_32x4_layout,
shared_4x16_to_mma_b_32x4_layout, shared_4x16_to_mma_b_32x4_layout,
...@@ -188,7 +189,7 @@ class TensorCoreIntrinEmitter: ...@@ -188,7 +189,7 @@ class TensorCoreIntrinEmitter:
def ldmatrix_a(self, def ldmatrix_a(self,
A_local_buf: Buffer, A_local_buf: Buffer,
A_shared_buf: Buffer, A_shared_buf: Buffer | BufferRegion,
ki: PrimExpr, ki: PrimExpr,
rk: PrimExpr | None = 0): rk: PrimExpr | None = 0):
warp_row_tiles = self.warp_row_tiles warp_row_tiles = self.warp_row_tiles
...@@ -205,6 +206,12 @@ class TensorCoreIntrinEmitter: ...@@ -205,6 +206,12 @@ class TensorCoreIntrinEmitter:
mma_load_layout = mma_load_a_32x4_to_shared_16x4_layout mma_load_layout = mma_load_a_32x4_to_shared_16x4_layout
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
@T.macro @T.macro
def _warp_ldmatrix_a( def _warp_ldmatrix_a(
A_local_buf, A_local_buf,
...@@ -220,13 +227,13 @@ class TensorCoreIntrinEmitter: ...@@ -220,13 +227,13 @@ class TensorCoreIntrinEmitter:
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
for j in T.vectorized(local_size_a): for j in T.vectorized(local_size_a):
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
A_local_buf[i * local_size_a + j] = A_shared_buf[wi + mi, wk + mk] A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)
def ldmatrix_b(self, def ldmatrix_b(self,
B_local_buf: Buffer, B_local_buf: Buffer,
B_shared_buf: Buffer, B_shared_buf: Buffer | BufferRegion,
ki: PrimExpr, ki: PrimExpr,
rk: PrimExpr | None = 0): rk: PrimExpr | None = 0):
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
...@@ -240,6 +247,12 @@ class TensorCoreIntrinEmitter: ...@@ -240,6 +247,12 @@ class TensorCoreIntrinEmitter:
mma_load_layout = mma_load_b_32x4_to_shared_16x4_layout_trans if b_transposed else mma_load_b_32x4_to_shared_4x16_layout mma_load_layout = mma_load_b_32x4_to_shared_16x4_layout_trans if b_transposed else mma_load_b_32x4_to_shared_4x16_layout
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
@T.macro @T.macro
def _warp_ldmatrix_b( def _warp_ldmatrix_b(
B_local_buf, B_local_buf,
...@@ -261,12 +274,14 @@ class TensorCoreIntrinEmitter: ...@@ -261,12 +274,14 @@ class TensorCoreIntrinEmitter:
for j in T.vectorized(local_size_b): for j in T.vectorized(local_size_b):
if b_transposed: if b_transposed:
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_shared_buf[wi + mi, wk + mk] B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi,
B_base1 + wk + mk]
else: else:
mk, mi = mma_load_layout(tx, j) mk, mi = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_shared_buf[wk + mk, wi + mi] B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk,
B_base1 + wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk)
def mma(self, def mma(self,
A_local_buf: Buffer, A_local_buf: Buffer,
......
...@@ -3,7 +3,8 @@ from enum import IntEnum ...@@ -3,7 +3,8 @@ from enum import IntEnum
import tilelang.language as T import tilelang.language as T
from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var from tvm.tir import PrimExpr, Buffer, Var, BufferLoad, BufferRegion
from tilelang import tvm as tvm
from tilelang import _ffi_api from tilelang import _ffi_api
from tilelang.utils import is_tensor_memory from tilelang.utils import is_tensor_memory
from tilelang.layout import ( from tilelang.layout import (
...@@ -245,13 +246,42 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -245,13 +246,42 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
mask_zero = T.Cast("int32", 0) mask_zero = T.Cast("int32", 0)
mask0 = mask1 = mask2 = mask3 = mask_zero mask0 = mask1 = mask2 = mask3 = mask_zero
# Helper to allow BufferRegion/BufferLoad as inputs
def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region.access_ptr(access_type)
elif isinstance(buffer_or_load_or_region, BufferLoad):
buffer_load = buffer_or_load_or_region
offset, stride = 0, 1
buffer = buffer_load.buffer
for i, shape in enumerate(reversed(buffer.shape)):
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
offset += indice * stride
elif isinstance(indice, tvm.tir.Ramp):
offset += indice.base * stride
else:
raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
buffer_region = buffer_or_load_or_region
buffer = buffer_region.buffer
offset, stride = 0, 1
for i, shape in enumerate(reversed(buffer.shape)):
offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
else:
raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
@T.macro @T.macro
def _warp_mma(A_buf, B_buf, C_local_buf, mbar): def _warp_mma(A_buf, B_buf, C_local_buf, mbar):
# Allocate SMEM descriptors for A and B # Allocate SMEM descriptors for A and B
desc_a = T.alloc_tcgen05_smem_desc() desc_a = T.alloc_tcgen05_smem_desc()
desc_b = T.alloc_tcgen05_smem_desc() desc_b = T.alloc_tcgen05_smem_desc()
A_ptr = A_buf.access_ptr("r") A_ptr = access_ptr_from(A_buf, "r")
B_ptr = B_buf.access_ptr("r") B_ptr = access_ptr_from(B_buf, "r")
T.initialize_tcgen05_descriptor( T.initialize_tcgen05_descriptor(
desc_a, desc_a,
......
...@@ -8,6 +8,7 @@ from .mma_layout import ( ...@@ -8,6 +8,7 @@ from .mma_layout import (
ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_a,
ldmatrix_32x16_to_shared_16x32_layout_b, ldmatrix_32x16_to_shared_16x32_layout_b,
mma_store_32x8_to_shared_16x16_layout, mma_store_32x8_to_shared_16x16_layout,
mma_store_32x2_to_shared_8x8_layout_fp64,
) )
from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m) from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m)
...@@ -82,6 +83,10 @@ def mma_store_index_map(thread_id, local_id): ...@@ -82,6 +83,10 @@ def mma_store_index_map(thread_id, local_id):
return mma_store_32x8_to_shared_16x16_layout(thread_id, local_id) return mma_store_32x8_to_shared_16x16_layout(thread_id, local_id)
def mma_store_index_map_fp64(thread_id, local_id):
return mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id)
def mfma_store_index_map(thread_id, local_id): def mfma_store_index_map(thread_id, local_id):
return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id) return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id)
......
@@ -4,8 +4,8 @@ from enum import IntEnum
from typing import Callable

from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var, IndexMap, BufferRegion
from tilelang.utils import is_fragment, retrive_ptr_from_buffer_region, is_full_region
from math import gcd
from tilelang.layout import (
    Layout,
@@ -161,14 +161,14 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
            raise ValueError(f"Unsupported swizzle mode: {layout}")

    def wgmma(self,
              A_region: BufferRegion,
              B_region: BufferRegion,
              C_region: BufferRegion,
              clear_accum: PrimExpr = False,
              wg_wait: int = 0):
        if is_fragment(A_region):
            return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait)

        local_size_out = self.local_size_out
        a_dtype_abbrv = self.a_dtype_abbrv
@@ -188,8 +188,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        a_is_k_major = not self.a_transposed
        b_is_k_major = self.b_transposed
        a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)

        elems_in_bits = DataType(self.a_dtype).bits
        elems_in_bytes = elems_in_bits // 8
@@ -263,26 +263,33 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        thread_binding = self.get_thread_binding()

        A_ptr = retrive_ptr_from_buffer_region(A_region)
        B_ptr = retrive_ptr_from_buffer_region(B_region)
        assert is_full_region(C_region), "Fragment output C must be a full region"
        C_buf = C_region.buffer

        @T.macro
        def _warp_mma(A_ptr, B_ptr, C_buf):
            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
            desc_a = T.alloc_wgmma_desc()
            desc_b = T.alloc_wgmma_desc()
            T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode,
                                          int(a_leading_byte_offset >> 4),
                                          int(a_stride_byte_offset >> 4))
            T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
                                          int(b_leading_byte_offset >> 4),
                                          int(b_stride_byte_offset >> 4))
            T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
            T.warpgroup_arrive()

            for j in T.unroll(num_inst_n):
                for i in T.unroll(num_inst_m):
                    for ki in T.unroll(k_dim // micro_size_k):
                        scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
                        warp_i = (warp_m // 4) * num_inst_m + i
                        warp_j = warp_n * num_inst_n + j
                        A_offset = (
                            ki % ak_atom_size
                        ) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + (
@@ -290,24 +297,27 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                        ) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
                        B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
                            ki % bk_atom_size
                        ) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else (
                            ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n *
                            (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
                        C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
                        T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
                                       a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
                                       (A_offset * elems_in_bytes) >> 4, desc_b.data,
                                       (B_offset * elems_in_bytes) >> 4, C_buf.data, C_offset,
                                       scale_out, scale_in_a, scale_in_b)
            T.warpgroup_commit_batch()
            if wg_wait >= 0:
                T.warpgroup_wait(wg_wait)
            T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)

        return _warp_mma(A_ptr, B_ptr, C_buf)

    def wgmma_rs(self,
                 A_region: BufferRegion,
                 B_region: BufferRegion,
                 C_region: BufferRegion,
                 clear_accum: PrimExpr = False,
                 wg_wait: int = 0):
        local_size_a = self.local_size_a
@@ -333,7 +343,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32

        b_is_k_major = self.b_transposed
        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
        b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
        ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
@@ -369,29 +379,37 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        thread_binding = self.get_thread_binding()

        assert is_full_region(A_region), "Fragment input A must be a full region"
        assert is_full_region(C_region), "Fragment output C must be a full region"
        A_buf = A_region.buffer
        B_ptr = retrive_ptr_from_buffer_region(B_region)
        C_buf = C_region.buffer

        @T.macro
        def _warp_mma(A_buf, B_ptr, C_buf):
            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
            desc_b = T.alloc_wgmma_desc()
            T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
                                          int(b_leading_byte_offset >> 4),
                                          int(b_stride_byte_offset >> 4))
            T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
            T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
            T.warpgroup_arrive()
            for j in T.unroll(0, num_inst_n):
                for i in T.unroll(num_inst_m):
                    for ki in T.unroll(0, (k_dim // micro_size_k)):
                        warp_j = warp_n * num_inst_n + j
                        scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
                        A_offset = ki * warp_rows * local_size_a + i * local_size_a
                        B_offset = (
                            ki // bk_atom_size
                        ) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + (
                            ki % bk_atom_size) * micro_size_k if b_is_k_major else (
                            ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n *
                            (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
                        C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
                        T.ptx_wgmma_rs(
                            accum_dtype,
@@ -404,19 +422,20 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                            A_offset,
                            desc_b.data,
                            (B_offset * elems_in_bytes) >> 4,
                            C_buf.data,
                            C_offset,
                            scale_out,
                            scale_in_a,
                            scale_in_b,
                        )
            T.warpgroup_commit_batch()
            if wg_wait >= 0:
                T.warpgroup_wait(wg_wait)
            T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
            T.warpgroup_fence_operand(A_buf, num_regs=a_regs)

        return _warp_mma(A_buf, B_ptr, C_buf)

    def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment:
        """
...
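With this refactor, `wgmma` and `wgmma_rs` take `BufferRegion` operands instead of raw `Buffer`s, so a caller that previously passed a shared buffer directly now wraps it in a region first. A small sketch of that wrapping using TVM's public TIR API; the shapes, names, and the "treat a whole buffer as a full region" helper are illustrative assumptions, not code from this diff:

```python
from tvm import tir
from tvm.ir import Range


def full_region(buf: tir.Buffer) -> tir.BufferRegion:
    """Wrap a Buffer as a BufferRegion spanning its entire shape."""
    return tir.BufferRegion(buf, [Range.from_min_extent(0, extent) for extent in buf.shape])


# Example shared operand; the emitter construction and T.macro context are elided here.
A_shared = tir.decl_buffer((128, 64), "float16", name="A_shared", scope="shared.dyn")
A_region = full_region(A_shared)

# A full region starts at the origin in every dimension, which is what
# checks like is_full_region() are expected to verify.
assert all(int(r.min) == 0 for r in A_region.region)
```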
@@ -8,7 +8,7 @@ from tilelang.utils.target import check_hip_availability
from tvm import DataType, tir
from tvm.runtime import convert
from typing import Any
from tvm.tir import PrimExpr, Var, Call, BufferLoad, BufferRegion

_IS_HIP_AVAILABLE = check_hip_availability()
@@ -440,21 +440,55 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
    WGMMA operations by issuing an empty inline assembly barrier on every register.

    Args:
        buffer_or_ptr: Buffer | BufferLoad | BufferRegion | PrimExpr
            A buffer representing the accumulator fragment, a buffer load/region
            that identifies a starting element within the fragment, or a pointer expression
            (e.g., tvm_access_ptr/address_of/typed Var).
        offset: int | PrimExpr
            Element offset from the start of the accumulator fragment.
        num_regs: int | PrimExpr | None
            Number of 32-bit registers to fence. If None and a Buffer is provided, it will be
            derived from the buffer shape and dtype.
        dtype: str | None
            Data type string of the accumulator elements. When passing a buffer or
            buffer-derived expression, dtype is inferred. It is required only when
            passing a raw pointer expression that cannot be inferred.

    Returns:
        tir.Call: A handle to the warpgroup fence operation.
    """
    if isinstance(buffer_or_ptr, BufferLoad):
        # Treat BufferLoad as a request to fence starting from the loaded element's address
        buf = buffer_or_ptr.buffer
        data_ptr = buf.data
        inferred_dtype = buf.dtype
        if dtype is not None and dtype != inferred_dtype:
            raise ValueError(f"dtype mismatch: provided {dtype}, buffer uses {inferred_dtype}.")
        dtype = inferred_dtype
        # Compute element offset from indices using strides if present, otherwise row-major
        if len(buf.strides) == len(buf.shape) and len(buf.strides) > 0:
            elem_off = 0
            for idx, stride in zip(buffer_or_ptr.indices, buf.strides):
                elem_off = elem_off + idx * stride
        else:
            elem_off = 0
            stride_acc = 1
            for idx, dim in zip(reversed(buffer_or_ptr.indices), reversed(buf.shape)):
                elem_off = elem_off + idx * stride_acc
                stride_acc = stride_acc * dim
        # Combine with user-provided offset
        offset = elem_off + convert(offset)
        if num_regs is None:
            raise ValueError("num_regs must be provided when passing a BufferLoad.")
        return evaluate(
            tir.call_intrin(
                "handle",
                tir.op.Op.get("tl.warpgroup_fence_operand"),
                dtype,
                data_ptr,
                convert(offset),
                convert(num_regs),
            ))

    if isinstance(buffer_or_ptr, tir.Buffer):
        data_ptr = buffer_or_ptr.data
@@ -472,10 +506,78 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
                "warpgroup_fence_operand requires num_regs when buffer shape is symbolic.")
            bits_per_elem = DataType(dtype).bits
            num_regs = (total_elems * bits_per_elem + 31) // 32
    elif isinstance(buffer_or_ptr, BufferRegion):
        buf = buffer_or_ptr.buffer
        data_ptr = buf.data
        inferred_dtype = buf.dtype
        if dtype is not None and dtype != inferred_dtype:
            raise ValueError(f"dtype mismatch: provided {dtype}, buffer uses {inferred_dtype}.")
        dtype = inferred_dtype
        # Compute element offset from region min using strides if present, otherwise row-major
        if len(buf.strides) == len(buf.shape) and len(buf.strides) > 0:
            elem_off = 0
            for r, stride in zip(buffer_or_ptr.region, buf.strides):
                elem_off = elem_off + r.min * stride
        else:
            elem_off = 0
            stride_acc = 1
            for r, dim in zip(reversed(buffer_or_ptr.region), reversed(buf.shape)):
                elem_off = elem_off + r.min * stride_acc
                stride_acc = stride_acc * dim
        # Combine with user-provided offset
        offset = elem_off + convert(offset)
        # Try derive num_regs from region extents if fully static; otherwise require user input
        if num_regs is None:
            total_elems = 1
            static = True
            for r in buffer_or_ptr.region:
                if isinstance(r.extent, tir.IntImm):
                    total_elems *= int(r.extent)
                else:
                    static = False
                    break
            if static:
                bits_per_elem = DataType(dtype).bits
                num_regs = (total_elems * bits_per_elem + 31) // 32
            else:
                raise ValueError(
                    "warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic."
                )
        return evaluate(
            tir.call_intrin(
                "handle",
                tir.op.Op.get("tl.warpgroup_fence_operand"),
                dtype,
                data_ptr,
                convert(offset),
                convert(num_regs),
            ))
    else:
        data_ptr = buffer_or_ptr
        # Try to infer dtype from common pointer expressions when not provided
        if dtype is None:
            inferred = None
            # Case 1: Pointer from Buffer.access_ptr -> tir.builtin.tvm_access_ptr
            if isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.tvm_access_ptr()):
                # args[0] is a type annotation call; its dtype carries the element dtype
                inferred = str(data_ptr.args[0].dtype)
            # Case 2: Pointer from tir.address_of(BufferLoad(...))
            elif isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.address_of()):
                # args[0] should be a BufferLoad; its dtype is the element dtype
                inferred = str(data_ptr.args[0].dtype)
            # Case 3: Typed pointer Var with PrimType element (typed TIR)
            elif hasattr(data_ptr, "type_annotation") and data_ptr.type_annotation is not None:
                try:
                    elem_ty = getattr(data_ptr.type_annotation, "element_type", None)
                    if elem_ty is not None and hasattr(elem_ty, "dtype"):
                        inferred = str(elem_ty.dtype)
                except Exception:
                    inferred = None
            if inferred is None:
                raise ValueError(
                    "dtype must be provided when passing a pointer expression and cannot be inferred."
                )
            dtype = inferred
        if num_regs is None:
            raise ValueError("num_regs must be provided when passing a pointer expression.")
...
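Both new branches reduce a multi-dimensional index (or a region's per-axis minimum) to a flat element offset, preferring explicit `strides` when the buffer carries them and otherwise falling back to row-major strides derived from the shape. The same arithmetic in plain Python, as a reference for the fallback path (pure illustration, not part of the library):

```python
def row_major_offset(indices, shape):
    """Flatten indices into an element offset assuming row-major (C-order) layout."""
    offset, stride = 0, 1
    # Walk dimensions from innermost to outermost, accumulating the running stride.
    for idx, dim in zip(reversed(indices), reversed(shape)):
        offset += idx * stride
        stride *= dim
    return offset


# e.g. element (2, 3) of a (64, 128) fragment sits 2 * 128 + 3 = 259 elements from the base
assert row_major_offset([2, 3], [64, 128]) == 259
```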
@@ -4,10 +4,19 @@ from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from tilelang.utils.language import (
    to_buffer_region,
    retrieve_shape,
    retrieve_stride,
    retrieve_ptr,
    retrieve_offset,
    prim_expr_equal,
)
from tilelang.env import env as _env


def _gemm_impl(
    op_key: str,
    A: tir.Buffer | tir.Var,
    B: tir.Buffer | tir.Var,
    C: tir.Buffer | tir.Var,
@@ -19,30 +28,9 @@ def gemm_v1(
    wg_wait: int = 0,
    mbar: tir.Buffer | None = None,
):
    """Shared GEMM implementation.

    Returns a call_intrin handle for the given op key.
    """

    def legalize_arguments(arg: tir.Buffer | tir.Var):
@@ -63,52 +51,10 @@ def gemm_v1(
    C = legalize_arguments(C)
    mbar = legalize_arguments(mbar) if mbar is not None else None

    # Normalize A/B/C to BufferRegion to pass into tl.gemm
    A = to_buffer_region(A)
    B = to_buffer_region(B)
    C = to_buffer_region(C)

    A_shape = retrieve_shape(A)
    B_shape = retrieve_shape(B)
@@ -132,68 +78,11 @@ def gemm_v1(
    M, N = C_shape
    K = A_shape[-2] if transpose_A else A_shape[-1]
    K_B = B_shape[-1] if transpose_B else B_shape[-2]
    assert prim_expr_equal(K, K_B), f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"

    stride_a = A_stride[-2]
    stride_b = B_stride[-2]

    A_offset = retrieve_offset(A)
    B_offset = retrieve_offset(B)
    assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0"
@@ -201,18 +90,15 @@ def gemm_v1(
    offset_a = A_offset[-1]
    offset_b = B_offset[-1]

    mbarptr = retrieve_ptr(mbar, "rw") if mbar is not None else tir.const(0, "uint32")
    C_coords = [r.min for r in C.region]

    return tir.call_intrin("handle", tir.op.Op.get(op_key), A, B, C, transpose_A, transpose_B, M, N,
                           K, policy, clear_accum, stride_a, stride_b, offset_a, offset_b, k_pack,
                           wg_wait, mbarptr, C_coords[0], C_coords[1])


# Public wrappers
def gemm_v1(
    A: tir.Buffer | tir.Var,
    B: tir.Buffer | tir.Var,
    C: tir.Buffer | tir.Var,
@@ -224,214 +110,50 @@ def gemm_v2(
    wg_wait: int = 0,
    mbar: tir.Buffer | None = None,
):
    """GEMM v1: use op tl.gemm."""
    return _gemm_impl(
        "tl.gemm",
        A,
        B,
        C,
        transpose_A,
        transpose_B,
        policy,
        clear_accum,
        k_pack,
        wg_wait,
        mbar,
    )


# experimental currently, for fast compilation
def gemm_v2(
    A: tir.Buffer | tir.Var,
    B: tir.Buffer | tir.Var,
    C: tir.Buffer | tir.Var,
    transpose_A: bool = False,
    transpose_B: bool = False,
    policy: GemmWarpPolicy = GemmWarpPolicy.Square,
    clear_accum: bool = False,
    k_pack: int = 1,
    wg_wait: int = 0,
    mbar: tir.Buffer | None = None,
):
    """GEMM v2: use op tl.gemm_py."""
    return _gemm_impl(
        "tl.gemm_py",
        A,
        B,
        C,
        transpose_A,
        transpose_B,
        policy,
        clear_accum,
        k_pack,
        wg_wait,
        mbar,
    )


# Default to v2; allow forcing v1 via environment variable
gemm = gemm_v1 if _env.use_gemm_v1() else gemm_v2
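The module-level `gemm` alias is resolved once at import time from `_env.use_gemm_v1()`. A minimal sketch of how such a check could be implemented on the environment object; the variable name `TILELANG_USE_GEMM_V1` and the truthy-string parsing are assumptions, not read from this diff:

```python
import os

_TRUTHY = {"1", "true", "yes", "on"}


def use_gemm_v1() -> bool:
    # Assumed behavior: an unset or falsy variable keeps the default dispatch (gemm_v2),
    # while any truthy value forces the tl.gemm lowering (gemm_v1).
    value = os.environ.get("TILELANG_USE_GEMM_V1", "")  # variable name is an assumption
    return value.strip().lower() in _TRUTHY


# Selecting v1 for a single run without touching code would then look like:
#   TILELANG_USE_GEMM_V1=1 python my_kernel.py
```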
"""Wrapping Layouts.""" """Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation # pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
import tvm import tvm
from tvm.tir import Buffer, BufferLoad, BufferRegion
from tilelang import _ffi_api from tilelang import _ffi_api
def _get_buffer_info(
buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion
) -> tuple[Buffer, list[int], str]:
"""
Extract buffer, shape, and dtype from Buffer, BufferLoad, or BufferRegion.
Args:
buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
Returns:
tuple: (buffer, shape, dtype)
"""
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region, buffer_or_load_or_region.shape, buffer_or_load_or_region.dtype
elif isinstance(buffer_or_load_or_region, (BufferLoad, BufferRegion)):
buf = buffer_or_load_or_region.buffer
return buf, buf.shape, buf.dtype
else:
raise TypeError(
f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}")
def _get_stride_continuous(
buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]:
"""
Get stride (last 2nd dimension) and continuous (last dimension) from Buffer, BufferLoad, or BufferRegion.
Args:
buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
Returns:
tuple: (stride, continuous) as integers
"""
_, shape, _ = _get_buffer_info(buffer_or_load_or_region)
stride = int(shape[-2])
continuous = int(shape[-1])
return stride, continuous
def _get_element_size(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> int:
"""
Get element size in bits from Buffer, BufferLoad, or BufferRegion.
Args:
buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
Returns:
int: Element size in bits
"""
_, _, dtype = _get_buffer_info(buffer_or_load_or_region)
return int(tvm.DataType(dtype).bits)
# Use a stable swizzled layout to ensure consistent memory access patterns. # Use a stable swizzled layout to ensure consistent memory access patterns.
# Swizzling should be enabled or disabled based on whether TMA (Tensor Memory Access) is applied. # Swizzling should be enabled or disabled based on whether TMA (Tensor Memory Access) is applied.
def make_swizzled_layout(buffer: tvm.tir.Buffer, k_major: bool = True, allow_pad: bool = True): def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
assert len(buffer.shape) == 2 k_major: bool = True,
allow_pad: bool = True):
stride, continuous = _get_stride_continuous(buffer)
element_size = _get_element_size(buffer)
return _ffi_api.make_swizzled_layout( return _ffi_api.make_swizzled_layout(
int(buffer.shape[0]), stride,
int(buffer.shape[1]), continuous,
int(tvm.DataType(buffer.dtype).bits), element_size,
k_major, k_major,
allow_pad, allow_pad,
) )
# for Volta Intrinsics # for Volta Intrinsics
def make_volta_swizzled_layout(buffer: tvm.tir.Buffer, is_a: bool = True, k_inner: bool = True): def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
assert len(buffer.shape) == 2 is_a: bool = True,
k_inner: bool = True):
stride, continuous = _get_stride_continuous(buffer)
return _ffi_api.make_volta_swizzled_layout( return _ffi_api.make_volta_swizzled_layout(
int(buffer.shape[0]), stride,
int(buffer.shape[1]), continuous,
is_a, is_a,
k_inner, k_inner,
) )
# for WGMMA Intrinsics # for WGMMA Intrinsics
def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer, def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
continuity: int = None, continuity: int = None,
k_major: bool = True): k_major: bool = True):
assert len(buffer.shape) == 2 stride, continuous = _get_stride_continuous(buffer)
element_size = _get_element_size(buffer)
if continuity is None: if continuity is None:
continuity = int(buffer.shape[1]) continuity = continuous
return _ffi_api.make_wgmma_swizzled_layout( return _ffi_api.make_wgmma_swizzled_layout(
int(buffer.shape[0]), stride,
int(buffer.shape[1]), continuous,
continuity, continuity,
int(tvm.DataType(buffer.dtype).bits), element_size,
k_major, k_major,
) )
# for TCGEN05MMA Intrinsics # for TCGEN05MMA Intrinsics
def make_tcgen05mma_swizzled_layout(buffer: tvm.tir.Buffer, def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
continuity: int = None, continuity: int = None,
k_major: bool = True): k_major: bool = True):
assert len(buffer.shape) == 2 stride, continuous = _get_stride_continuous(buffer)
element_size = _get_element_size(buffer)
if continuity is None: if continuity is None:
continuity = int(buffer.shape[1]) continuity = continuous
return _ffi_api.make_tcgen05mma_swizzled_layout( return _ffi_api.make_tcgen05mma_swizzled_layout(
int(buffer.shape[0]), stride,
int(buffer.shape[1]), continuous,
continuity, continuity,
int(tvm.DataType(buffer.dtype).bits), element_size,
k_major, k_major,
) )
...@@ -66,15 +128,14 @@ def make_tcgen05mma_swizzled_layout(buffer: tvm.tir.Buffer, ...@@ -66,15 +128,14 @@ def make_tcgen05mma_swizzled_layout(buffer: tvm.tir.Buffer,
def make_full_bank_swizzled_layout(*args): def make_full_bank_swizzled_layout(*args):
""" """
Args: Args:
args: buffer or (stride, continuous, element_size) args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size)
Examples: Examples:
make_full_bank_swizzled_layout(buffer) make_full_bank_swizzled_layout(buffer)
make_full_bank_swizzled_layout(stride, continuous, element_size) make_full_bank_swizzled_layout(stride, continuous, element_size)
""" """
if len(args) == 1: if len(args) == 1:
buffer = args[0] stride, continuous = _get_stride_continuous(args[0])
stride, continuous = int(buffer.shape[0]), int(buffer.shape[1]) element_size = _get_element_size(args[0])
element_size = int(tvm.DataType(buffer.dtype).bits)
elif len(args) == 3: elif len(args) == 3:
stride, continuous, element_size = args stride, continuous, element_size = args
else: else:
...@@ -91,15 +152,14 @@ def make_full_bank_swizzled_layout(*args): ...@@ -91,15 +152,14 @@ def make_full_bank_swizzled_layout(*args):
def make_half_bank_swizzled_layout(*args): def make_half_bank_swizzled_layout(*args):
""" """
Args: Args:
args: buffer or (stride, continuous, element_size) args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size)
Examples: Examples:
make_half_bank_swizzled_layout(buffer) make_half_bank_swizzled_layout(buffer)
make_half_bank_swizzled_layout(stride, continuous, element_size) make_half_bank_swizzled_layout(stride, continuous, element_size)
""" """
if len(args) == 1: if len(args) == 1:
buffer = args[0] stride, continuous = _get_stride_continuous(args[0])
stride, continuous = int(buffer.shape[0]), int(buffer.shape[1]) element_size = _get_element_size(args[0])
element_size = int(tvm.DataType(buffer.dtype).bits)
elif len(args) == 3: elif len(args) == 3:
stride, continuous, element_size = args stride, continuous, element_size = args
else: else:
...@@ -116,15 +176,14 @@ def make_half_bank_swizzled_layout(*args): ...@@ -116,15 +176,14 @@ def make_half_bank_swizzled_layout(*args):
def make_quarter_bank_swizzled_layout(*args): def make_quarter_bank_swizzled_layout(*args):
""" """
Args: Args:
args: buffer or (stride, continuous, element_size) args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size)
Examples: Examples:
make_quarter_bank_swizzled_layout(buffer) make_quarter_bank_swizzled_layout(buffer)
make_quarter_bank_swizzled_layout(stride, continuous, element_size) make_quarter_bank_swizzled_layout(stride, continuous, element_size)
""" """
if len(args) == 1: if len(args) == 1:
buffer = args[0] stride, continuous = _get_stride_continuous(args[0])
stride, continuous = int(buffer.shape[0]), int(buffer.shape[1]) element_size = _get_element_size(args[0])
element_size = int(tvm.DataType(buffer.dtype).bits)
elif len(args) == 3: elif len(args) == 3:
stride, continuous, element_size = args stride, continuous, element_size = args
else: else:
...@@ -139,14 +198,13 @@ def make_quarter_bank_swizzled_layout(*args): ...@@ -139,14 +198,13 @@ def make_quarter_bank_swizzled_layout(*args):
def make_linear_layout(*args): def make_linear_layout(*args):
""" """
Args: Args:
args: buffer or (stride, continuous) args: buffer/BufferLoad/BufferRegion or (stride, continuous)
Examples: Examples:
make_linear_layout(buffer) make_linear_layout(buffer)
make_linear_layout(stride, continuous) make_linear_layout(stride, continuous)
""" """
if len(args) == 1: if len(args) == 1:
buffer = args[0] stride, continuous = _get_stride_continuous(args[0])
stride, continuous = int(buffer.shape[0]), int(buffer.shape[1])
elif len(args) == 2: elif len(args) == 2:
stride, continuous = args stride, continuous = args
else: else:
......
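Because the layout constructors now route through `_get_stride_continuous` and `_get_element_size`, they accept a `Buffer`, a `BufferLoad`, a `BufferRegion`, or explicit scalars. A usage sketch under assumptions: the module path, buffer shape, and dtype below are illustrative only.

```python
import tvm
from tvm import tir
from tilelang.layout.swizzle import (  # module path assumed
    make_swizzled_layout,
    make_full_bank_swizzled_layout,
)

# Example 128x64 fp16 shared-memory tile.
A_shared = tir.decl_buffer((128, 64), "float16", name="A_shared", scope="shared.dyn")

# Buffer argument: stride, continuous, and element size are derived from the buffer itself.
layout_a = make_swizzled_layout(A_shared, k_major=True)

# Explicit arguments: the equivalent call with stride=128, continuous=64, element_size=16 bits.
layout_b = make_full_bank_swizzled_layout(128, 64, int(tvm.DataType("float16").bits))
```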