Unverified Commit 91a7bb2b authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[TileOp] Introduce a experimental python defined `T.gemm_v2` (#793)

* Refactor GEMM and GEMM-SP operations to enhance clarity and maintainability

- Removed deprecated prime factorization functions from `gemm.cc` and `gemm_sp.cc`.
- Introduced a new `GemmWarpPolicy` class to manage warp policy attributes and methods, improving encapsulation.
- Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities.
- Enhanced `GetArchInt` function in `utils.cc` for better readability and type safety.
- Added new `gemm_v2` function in `gemm.py` for improved GEMM operation with additional parameters and checks.

* Refactor GEMM and frontend legalize operations for improved clarity and functionality

- Updated `gemm_py.h` to include the correct header for GEMM operations.
- Renamed `FrontendLegalizer` class to `LetInliner` and updated related methods to reflect this change, enhancing code clarity.
- Modified the pass function from `FrontendLegalize` to `LetInline` for better alignment with its purpose.
- Updated test cases to utilize the new `gemm_v2` function and adjusted the testing framework for improved output and clarity.
- Removed obsolete test file `test_tilelang_transform_frontend_legalize.py` to streamline the test suite.
- Enhanced the `LowerAndLegalize` function to utilize the new `LetInline` pass, improving the overall transformation process.

* Enhance CUDA code generation and testing for GEMM operations

- Added indentation printing in `codegen_cuda.cc` for improved assembly code formatting.
- Updated `test_tilelang_tilelibrary_gemm.py` to include additional GEMM test cases and shared memory allocation with specified scope.
- Introduced new `matmul_sr` and `run_gemm_sr` functions for GEMM operations with shared and fragment memory layouts.
- Refactored layout inference in `mma_macro_generator.py` to improve clarity and correctness in shared memory handling.
- Enhanced `gemm/__init__.py` to support new GEMM operation combinations and layout inference logic.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.

* Refactor GEMM layout and testing for improved clarity and functionality

- Updated `gemm_layouts.cc` to enhance the layout generation logic for transposed and non-transposed GEMM operations.
- Renamed and modified functions in `test_tilelang_tilelibrary_gemm.py` to reflect changes in GEMM function signatures and improve test coverage.
- Introduced new GEMM operation combinations in `gemm/__init__.py` to support additional layouts and configurations.
- Enhanced layout inference in `mma_layout.py` and `mma_macro_generator.py` for better handling of shared memory layouts.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.

* Refactor GEMM layout and Python integration for improved functionality

- Updated `gemm_layouts.cc` to correct the order of layout replication and repetition for transposed and non-transposed GEMM operations.
- Enhanced `gemm_py.cc` to handle block realization more robustly, ensuring correct assignment of global symbols and block attributes.
- Refactored `inject_pipeline.cc` to streamline buffer read/write region handling, improving clarity and maintainability.
- Cleaned up test cases in `test_tilelang_tilelibrary_gemm.py` by removing unnecessary print statements and adjusting function calls for better test execution flow.

These changes enhance the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.

* Refactor GEMM layout and testing for improved clarity and functionality

- Updated `gemm_layouts.cc` to enhance layout generation logic for transposed and non-transposed GEMM operations.
- Improved block realization handling in `gemm_py.cc` for better assignment of global symbols.
- Streamlined buffer read/write region handling in `inject_pipeline.cc` for clarity.
- Enhanced test cases in `test_tilelang_tilelibrary_gemm.py` by adjusting function calls and adding new GEMM operation combinations.

These changes improve the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.

* tfloat32 support.

* lint fix

* lint fix

* Refactor shared memory allocation in GEMM tests

- Removed unnecessary scope specification in shared memory allocation for matrices A and B in `test_tilelang_tilelibrary_gemm.py`.
- This change simplifies the allocation process and aligns with the updated GEMM function signatures.
parent 9fd6bb30
...@@ -41,6 +41,7 @@ Checks: > ...@@ -41,6 +41,7 @@ Checks: >
-clang-analyzer-optin.cplusplus.UninitializedObject, -clang-analyzer-optin.cplusplus.UninitializedObject,
-cppcoreguidelines-pro-type-static-cast-downcast, -cppcoreguidelines-pro-type-static-cast-downcast,
-performance-unnecessary-value-param, -performance-unnecessary-value-param,
-performance-enum-size,
WarningsAsErrors: '*' WarningsAsErrors: '*'
......
Subproject commit 1fc7578cd1ff934455b07597508b5a67d7cb5a73 Subproject commit eddefbd65acb7b1ea51dd18068b4049754c4fa7a
...@@ -132,6 +132,7 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS ...@@ -132,6 +132,7 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS
if(USE_CUDA) if(USE_CUDA)
tilelang_file_glob(GLOB TILE_LANG_CUDA_SRCS tilelang_file_glob(GLOB TILE_LANG_CUDA_SRCS
src/runtime/*.cc src/runtime/*.cc
src/target/ptx.cc
src/target/codegen_cuda.cc src/target/codegen_cuda.cc
src/target/rt_mod_cuda.cc src/target/rt_mod_cuda.cc
) )
......
...@@ -402,7 +402,6 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, ...@@ -402,7 +402,6 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
PassContext pass_ctx = PassContext::Current(); PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower = bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value(); pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
T.layout_map, T.analyzer, T.buffer_oob); T.layout_map, T.analyzer, T.buffer_oob);
if (copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) { if (copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) {
......
...@@ -18,30 +18,6 @@ namespace tl { ...@@ -18,30 +18,6 @@ namespace tl {
using namespace tir; using namespace tir;
/**
* @brief Compute the prime factorization of an integer.
*
* Returns the prime factors of x in non-decreasing order by repeatedly dividing
* out the smallest possible factor.
*
* @param x Integer to factorize. If x <= 1, an empty vector is returned.
* @return std::vector<int> Prime factors of x (with multiplicity), in
* non-decreasing order.
*/
static std::vector<int> toPrimeFactors(int x) {
int i = 2;
std::vector<int> result;
while (x > 1) {
if (x % i == 0) {
x /= i;
result.push_back(i);
} else {
i++;
}
}
return result;
}
/** /**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer * @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map. * map.
...@@ -268,7 +244,6 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size, ...@@ -268,7 +244,6 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
int best_m = 1; int best_m = 1;
int best_n = 1; int best_n = 1;
float best_balance = std::numeric_limits<float>::max(); float best_balance = std::numeric_limits<float>::max();
// Try all possible combinations that satisfy the constraints // Try all possible combinations that satisfy the constraints
for (int m = 1; m <= max_m_warps && m <= num_warps; m++) { for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
int n = num_warps / m; int n = num_warps / m;
...@@ -276,6 +251,13 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size, ...@@ -276,6 +251,13 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
// Calculate how balanced this partition is // Calculate how balanced this partition is
float m_per_warp = static_cast<float>(M) / (m * kMPerWarp); float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
float n_per_warp = static_cast<float>(N) / (n * kNPerWarp); float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
// m_per_warp and n_per_warp must be greater than 1
if (m_per_warp < 1 || n_per_warp < 1)
continue;
// m * n must equal num_warps
if (m * n != num_warps)
continue;
float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio); float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);
if (balance < best_balance) { if (balance < best_balance) {
...@@ -290,7 +272,6 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size, ...@@ -290,7 +272,6 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
} else { } else {
ICHECK(0) << "Unknown GemmWarpPolicy"; ICHECK(0) << "Unknown GemmWarpPolicy";
} }
// Store the computed values in the object's member variables // Store the computed values in the object's member variables
this->m_warp = m_warp; this->m_warp = m_warp;
this->n_warp = n_warp; this->n_warp = n_warp;
...@@ -632,5 +613,21 @@ TIR_REGISTER_TL_OP(Gemm, gemm) ...@@ -632,5 +613,21 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TVM_REGISTER_OP("tl.GemmWarpPolicy")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");
TVM_FFI_STATIC_INIT_BLOCK({
GemmNode::RegisterReflection();
GemmWarpPolicyNode::RegisterReflection();
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
[](GemmWarpPolicy policy, int M, int N, int block_size,
Target target, bool is_wgmma) {
policy->ComputeWarpPartition(M, N, block_size, target,
is_wgmma);
return;
});
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
/*!
* \file tl/op/gemm_py.cc
* \brief Implementation of General Matrix Multiplication (GEMM) operators
*/
#include "gemm_py.h"
#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "tvm/ffi/string.h"
namespace tvm {
namespace tl {
using namespace tir;
/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
*
* This constructor deserializes operator parameters from `args` and resolves
* buffer references via `vmap`, populating an internal GemmPyNode with:
* - device pointers for A, B, C and their corresponding Buffer objects,
* - transpose flags for A and B,
* - matrix dimensions M, N, K,
* - warp allocation policy and clear_accum flag,
* - strides and memory offsets for A and B,
* - optional kPack (must be 1 or 2) and optional wg_wait.
*
* The populated GemmPyNode is stored into the wrapper's internal `data_`.
*
* @param args Positional serialized arguments produced by the TL frontend:
* expected layout is:
* [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
*
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
*/
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();
node->Aptr = args[0];
node->Bptr = args[1];
node->Cptr = args[2];
node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
node->trans_A = args[3].as<Bool>().value();
node->trans_B = args[4].as<Bool>().value();
node->M = args[5].as<IntImm>().value()->value;
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clear_accum = args[9].as<Bool>().value();
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
node->offset_B = args[13].as<IntImm>().value()->value;
if (args.size() > 14) {
node->kPack = args[14].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
}
data_ = std::move(node);
}
/**
* @brief Create a copy of this GemmPyNode as a TileOperator.
*
* Constructs a new GemmPyNode by copying the current node state and returns it
* wrapped in a Gemm TileOperator.
*
* @return TileOperator A Gemm operator that owns a copy of this node.
*/
TileOperator GemmPyNode::Clone() const {
auto op = make_object<GemmPyNode>(*this);
return GemmPy(op);
}
GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size,
Target target) const {
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
}
}
/**
* @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
*
* Evaluates device-memory placement, data-type combinations, transpose flags,
* and K divisibility constraints required for the Hopper WGMMA code path.
*
* The check returns true only when:
* - B resides in shared memory ("shared" or "shared.dyn"); and
* - (C, A, B) dtypes match one of the supported combinations below and K
* satisfies the required alignment; and
* - for combinations that require specific orientations, A is not transposed
* and B is transposed.
*
* Supported combinations and constraints:
* - C=float16:
* - A=float16, B=float16: K % 16 == 0
* - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K %
* 32 == 0
* - C=float32:
* - A=float16, B=float16: K % 16 == 0
* - A=bfloat16, B=bfloat16: K % 16 == 0
* - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
* - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
* - C=int32:
* - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
* and K % 32 == 0
*
* @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise.
*/
bool GemmPyNode::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
return false;
}
if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Float(32)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype == DataType::BFloat(16) &&
B->dtype == DataType::BFloat(16))
return K % 16 == 0;
else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
return (!trans_A) && trans_B && K % 8 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Int(32)) {
if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
else
return false;
} else {
return false;
}
}
/**
* @brief Parse and return the numeric GPU architecture from a Target's "arch"
* attribute.
*
* Examines the target's "arch" string and, if it matches the pattern
* "sm_<num>", returns <num> as an int. If the attribute is present but does not
* match that pattern, returns 0.
*
* Preconditions: the target must have an "arch" attribute (this is checked via
* ICHECK).
*
* @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
* the arch string does not match "sm_<num>".
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
} else {
arch_int = 0;
}
return arch_int;
}
Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
auto prim_func = Downcast<PrimFunc>(
(*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined());
auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
ICHECK(global_symbol.defined());
if (prim_func->body.as<BlockRealizeNode>()) {
BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
auto block = block_realize->block;
{
BlockNode *n = block.CopyOnWrite();
n->name_hint = global_symbol.value();
}
return BlockRealize(block_realize->iter_values, block_realize->predicate,
block);
}
// warp with block realize node
return BlockRealize(
/*iter_values=*/Array<PrimExpr>(),
/*predicate=*/const_true(),
/*block=*/
Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/global_symbol.value(), prim_func->body));
} else {
LOG(FATAL) << "No lower function found for gemm_py";
}
}
LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (completed_)
return {};
LayoutMap results;
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
results = Downcast<LayoutMap>(
(*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds));
} else {
LOG(FATAL) << "No infer layout function found for gemm_py";
}
completed_ = true;
return results;
}
TIR_REGISTER_TL_OP(GemmPy, gemm_py)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); });
} // namespace tl
} // namespace tvm
/*!
* \file tl/op/gemm_py.h
* \brief Define gemm operator.
*
*/
#ifndef TVM_TL_OP_GEMM_PY_H_
#define TVM_TL_OP_GEMM_PY_H_
#include "gemm.h"
#include "operator.h"
namespace tvm {
namespace tl {
using namespace tir;
class GemmPyNode : public TileOperatorNode {
public:
bool CheckWGMMA() const;
tir::Buffer A, B, C;
// pointer to the A, B, C
PrimExpr Aptr, Bptr, Cptr;
bool trans_A, trans_B;
int M, N, K;
int stride_A, stride_B;
int offset_A, offset_B;
bool clear_accum = false;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
int kPack = 1;
int wg_wait = 0;
mutable GemmWarpPolicy policy;
static constexpr const char *_type_key = "tl.GemmPy";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmPyNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmPyNode>()
.def_ro("A", &GemmPyNode::A)
.def_ro("B", &GemmPyNode::B)
.def_ro("C", &GemmPyNode::C)
.def_ro("Aptr", &GemmPyNode::Aptr)
.def_ro("Bptr", &GemmPyNode::Bptr)
.def_ro("Cptr", &GemmPyNode::Cptr)
.def_ro("trans_A", &GemmPyNode::trans_A)
.def_ro("trans_B", &GemmPyNode::trans_B)
.def_ro("M", &GemmPyNode::M)
.def_ro("N", &GemmPyNode::N)
.def_ro("K", &GemmPyNode::K)
.def_ro("stride_A", &GemmPyNode::stride_A)
.def_ro("stride_B", &GemmPyNode::stride_B)
.def_ro("offset_A", &GemmPyNode::offset_A)
.def_ro("offset_B", &GemmPyNode::offset_B)
.def_ro("clear_accum", &GemmPyNode::clear_accum)
.def_ro("kPack", &GemmPyNode::kPack)
.def_ro("wg_wait", &GemmPyNode::wg_wait)
.def_ro("policy", &GemmPyNode::policy);
}
bool SEqualReduce(const GemmPyNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_B) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
equal(policy, other->policy);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(Aptr);
hash_reduce(Bptr);
hash_reduce(Cptr);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(stride_A);
hash_reduce(stride_B);
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
TileOperator Clone() const;
private:
// Target GEMM instruction
enum class GemmInst : uint8_t { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const;
mutable bool completed_ = false;
};
class GemmPy : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmPy, TileOperator, GemmPyNode);
TVM_DLL GemmPy(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_GEMM_PY_H_
\ No newline at end of file
...@@ -17,30 +17,6 @@ ...@@ -17,30 +17,6 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
/**
* @brief Decomposes a positive integer into its prime factors.
*
* Returns the prime factorization of `x` as a vector of prime factors in
* non-decreasing order. If `x <= 1` the returned vector is empty.
*
* @param x Integer to factorize (expected non-negative; behavior: returns empty
* for values <= 1).
* @return std::vector<int> Prime factors of `x` (with repetition), e.g. 12 ->
* {2, 2, 3}.
*/
static std::vector<int> toPrimeFactors(int x) {
int i = 2;
std::vector<int> result;
while (x > 1) {
if (x % i == 0) {
x /= i;
result.push_back(i);
} else {
i++;
}
}
return result;
}
/** /**
* @brief Construct a GemmSP operator node from TL call arguments and a buffer * @brief Construct a GemmSP operator node from TL call arguments and a buffer
......
...@@ -14,11 +14,12 @@ ...@@ -14,11 +14,12 @@
#include <vector> #include <vector>
#include "../op/builtin.h" #include "../op/builtin.h"
#include "./ptx.h"
#include "arith/pattern_match.h" #include "arith/pattern_match.h"
#include "target/source/ptx.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace tvm::tl::codegen;
static std::string GetFP8Type(DataType type) { static std::string GetFP8Type(DataType type) {
std::stringstream stream; std::stringstream stream;
...@@ -1259,7 +1260,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -1259,7 +1260,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
std::string asm_code = PrintMMAAssembly( std::string asm_code = PrintMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
this->PrintIndent();
this->stream << asm_code; this->stream << asm_code;
} else if (op->op.same_as(builtin::ptx_mma_sp())) { } else if (op->op.same_as(builtin::ptx_mma_sp())) {
// arg 0: shape: mXnXkX // arg 0: shape: mXnXkX
...@@ -1295,6 +1296,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -1295,6 +1296,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
std::string metadata_offset = this->PrintExpr(op->args[13]); std::string metadata_offset = this->PrintExpr(op->args[13]);
std::string sparse_selector = this->PrintExpr(op->args[14]); std::string sparse_selector = this->PrintExpr(op->args[14]);
bool saturate = Downcast<Bool>(op->args[15])->value; bool saturate = Downcast<Bool>(op->args[15])->value;
this->PrintIndent();
std::string asm_code = PrintMMAAssembly( std::string asm_code = PrintMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset,
b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset, b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset,
...@@ -1330,10 +1332,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -1330,10 +1332,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
os << "}\n"; os << "}\n";
} else { } else {
std::string smem_elem_offset = this->PrintExpr(op->args[6]); std::string smem_elem_offset = this->PrintExpr(op->args[6]);
need_cast_smem_ptr_to_int_ = true; std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, if (trans == 1)
local_elem_offset, smem_ptr, func_name += "_trans";
smem_elem_offset); this->PrintIndent();
this->stream << func_name << "(" << smem_ptr << " + " << smem_elem_offset
<< ", " << local_ptr << " + " << local_elem_offset << ");\n";
} }
} else if (op->op.same_as(builtin::mma_store())) { } else if (op->op.same_as(builtin::mma_store())) {
int m = Downcast<Integer>(op->args[0])->value; int m = Downcast<Integer>(op->args[0])->value;
......
This diff is collapsed.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ptx.h
* \brief Code generation with inlined PTX code.
*/
#ifndef TVM_TL_TARGET_SOURCE_PTX_H_
#define TVM_TL_TARGET_SOURCE_PTX_H_
#include <tvm/runtime/logging.h>
#include <string>
#include <tuple>
namespace tvm::tl {
namespace codegen {
/*!
* \brief Print MMA assembly string given parameters.
* \param shape The shape string mMnNkK
* \param A_layout The layout of multiplicand A, can be either "row" or "col".
* \param B_layout The layout of multiplicand B, can be either "row" or "col".
* \param A_dtype The data type of multiplicand A.
* \param B_dtype The data type of multiplicand B.
* \param C_dtype The data type of multiplicand C.
* \param a_ptr Pointer to buffer A.
* \param a_offset The offset of element in A.
* \param b_ptr Pointer to buffer B.
* \param b_offset The offset of element in B.
* \param c_ptr Pointer to buffer C.
* \param c_offset The offset of element in C.
* \param metadata Pointer to metadata buffer (only used for sparse mma).
* \param metadata_offset The offset of element in metadata.
* \param sparsity_selector The sparsity selector in sparse mma.
* \param bit_op The bit operator used in 1-bit mma, can be either "xor" or
* "and". \param sparse Whether it's sparse mma or not. \param saturate Whether
* saturate output or not.
*/
std::string
PrintMMAAssembly(const std::string &shape, const std::string &A_layout,
const std::string &B_layout, const std::string &A_dtype,
const std::string &B_dtype, const std::string &C_dtype,
const std::string &a_ptr, const std::string &a_offset,
const std::string &b_ptr, const std::string &b_offset,
const std::string &c_ptr, const std::string &c_offset,
const std::string &metadata,
const std::string &metadata_offset,
const std::string &sparsity_selector,
const std::string &bit_op, bool sparse, bool saturate);
/*!
* \brief Print ldmatrix assembly string given parameters.
* \param trans: whether the matrix is loaded in column major format or not.
* \param num: number of matrices to load.
* \param type: The data type in the matrix, .b16 is the only accepted data
* type. \param local_ptr: pointer to local buffer. \param local_elem_offset:
* The offset of the element to store in the local buffer. \param smem_ptr:
* pointer to the shared memory buffer to load. \param smem_elem_offset: The
* offset of the start element of the row to load in shared memory.
*/
std::string PrintLoadMatrixAssembly(bool trans, int num,
const std::string &type,
const std::string &local_ptr,
const std::string &local_elem_offset,
const std::string &smem_ptr,
const std::string &smem_elem_offset);
/*!
* \brief Print ptx cp.async assembly string given parameters.
* \param shared_ptr: The pointer to the destination shared memory.
* \param shared_elem_offset: The offset into the shared memory.
* \param global_ptr: The pointer to the global memory.
* \param global_elem_offset: The offset into the global memory.
* \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
*/
std::string PrintCpAsyncAssembly(const std::string &shared_ptr,
const std::string &shared_elem_offset,
const std::string &global_ptr,
const std::string &global_elem_offset,
const std::string &bytes);
/*!
* \brief Print predicated ptx cp.async assembly string given parameters.
* \param shared_ptr: The pointer to the destination shared memory.
* \param shared_elem_offset: The offset into the shared memory.
* \param global_ptr: The pointer to the global memory.
* \param global_elem_offset: The offset into the global memory.
* \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
* \param predicate_value: The value of predicate `@p`.
*/
std::string PrintPredicatedCpAsyncAssembly(
const std::string &shared_ptr, const std::string &shared_elem_offset,
const std::string &global_ptr, const std::string &global_elem_offset,
const std::string &bytes, const std::string &predicate_value);
/*!
* \brief Print ptx async copy from global to shared memory using cp.async.bulk
* \param shared_ptr: The pointer to the destination shared memory.
* \param shared_elem_offset: The offset into the shared memory.
* \param global_ptr: The pointer to the global memory.
* \param global_elem_offset: The offset into the global memory.
* \param bytes: The number of bytes to copy.
* \param barrier: The name of the barrier in shared memory.
*/
std::string PrintCpAsyncBulkAsm(const std::string &shared_ptr,
const std::string &shared_elem_offset,
const std::string &global_ptr,
const std::string &global_elem_offset,
const std::string &bytes,
const std::string &barrier);
/*!
* \brief Print ptx async copy barrier using cp.async.mbarrier.arrive
* \param barrier: The name of the barrier in shared memory.
*/
std::string PrintCpAsyncBarrierAsm(const std::string &barrier);
/*!
* \brief Print ptx barrier initialization of thread count using mbarrier.init
* \param barrier: The name of the barrier in shared memory.
* \param thread_count: The number of threads expected to arrive at the barrier.
*/
std::string PrintInitBarrierThreadCountAsm(const std::string &barrier,
const std::string &thread_count);
/*!
* \brief Print ptx barrier arrival using mbarrier.arrive
* \param barrier: The name of the barrier in shared memory.
*/
std::string PrintArriveBarrierAsm(const std::string &barrier);
/*!
* \brief Print ptx barrier arrival with expect tx operation using
* mbarrier.arrive.expect_tx \param barrier: The name of the barrier in shared
* memory. \param byte_count: Increases the tx count of the mbarrier object to
* track completion of addtional async transactions.
*/
std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier,
const std::string &byte_count);
/*!
* \brief Print ptx barrier wait using mbarrier.try_wait
* \param barrier: The name of the barrier in shared memory.
*/
std::string PrintWaitBarrierAsm(const std::string &barrier);
} // namespace codegen
} // namespace tvm::tl
#endif // TVM_TL_TARGET_SOURCE_PTX_H_
...@@ -18,11 +18,11 @@ bool TargetIsRocm(Target target) { ...@@ -18,11 +18,11 @@ bool TargetIsRocm(Target target) {
int GetArchInt(Target target) { int GetArchInt(Target target) {
auto s = target->GetAttr<String>("arch"); auto s = target->GetAttr<String>("arch");
ICHECK(s.defined()); ICHECK(s.defined());
const char *arch_str = s.value().c_str(); const std::string arch_str = s.value();
ICHECK_EQ(arch_str[0], 's'); ICHECK(arch_str.size() >= 3);
ICHECK_EQ(arch_str[1], 'm'); ICHECK_EQ(arch_str.compare(0, 3, "sm_"), 0)
ICHECK_EQ(arch_str[2], '_'); << "arch string must start with sm_";
return atoi(&arch_str[3]); return std::stoi(arch_str.substr(3));
} }
bool TargetIsVolta(Target target) { bool TargetIsVolta(Target target) {
...@@ -118,5 +118,36 @@ int TargetGetWarpSize(Target target) { ...@@ -118,5 +118,36 @@ int TargetGetWarpSize(Target target) {
return res; return res;
} }
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tl.TargetIsCuda",
[](Target target) { return TargetIsCuda(target); })
.def("tl.TargetIsRocm",
[](Target target) { return TargetIsRocm(target); })
.def("tl.TargetIsVolta",
[](Target target) { return TargetIsVolta(target); })
.def("tl.TargetIsTuring",
[](Target target) { return TargetIsTuring(target); })
.def("tl.TargetIsAmpere",
[](Target target) { return TargetIsAmpere(target); })
.def("tl.TargetIsHopper",
[](Target target) { return TargetIsHopper(target); })
.def("tl.TargetIsSM120",
[](Target target) { return TargetIsSM120(target); })
.def("tl.TargetIsCDNA",
[](Target target) { return TargetIsCDNA(target); })
.def("tl.TargetHasAsyncCopy",
[](Target target) { return TargetHasAsyncCopy(target); })
.def("tl.TargetHasLdmatrix",
[](Target target) { return TargetHasLdmatrix(target); })
.def("tl.TargetHasStmatrix",
[](Target target) { return TargetHasStmatrix(target); })
.def("tl.TargetHasBulkCopy",
[](Target target) { return TargetHasBulkCopy(target); })
.def("tl.TargetGetWarpSize",
[](Target target) { return TargetGetWarpSize(target); });
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -34,11 +34,11 @@ namespace tl { ...@@ -34,11 +34,11 @@ namespace tl {
using namespace tir; using namespace tir;
class FrontendLegalizer : public arith::IRMutatorWithAnalyzer { class LetInliner : public arith::IRMutatorWithAnalyzer {
public: public:
static PrimFunc Substitute(PrimFunc f) { static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
FrontendLegalizer substituter(&analyzer); LetInliner substituter(&analyzer);
PrimFuncNode *fptr = f.CopyOnWrite(); PrimFuncNode *fptr = f.CopyOnWrite();
fptr->body = substituter.VisitStmt(f->body); fptr->body = substituter.VisitStmt(f->body);
return f; return f;
...@@ -82,16 +82,16 @@ private: ...@@ -82,16 +82,16 @@ private:
using namespace tir::transform; using namespace tir::transform;
Pass FrontendLegalize() { Pass LetInline() {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return FrontendLegalizer::Substitute(std::move(f)); return LetInliner::Substitute(std::move(f));
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.FrontendLegalize", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.FrontendLegalize", FrontendLegalize); refl::GlobalDef().def("tl.transform.LetInline", LetInline);
}); });
} // namespace tl } // namespace tl
......
...@@ -248,7 +248,6 @@ public: ...@@ -248,7 +248,6 @@ public:
buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions)); buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
} }
} }
ordered_stmts_.resize(pipeline_info_.size()); ordered_stmts_.resize(pipeline_info_.size());
for (const auto &[block, anno] : pipeline_info_) { for (const auto &[block, anno] : pipeline_info_) {
ordered_stmts_.Set(anno.order, block); ordered_stmts_.Set(anno.order, block);
...@@ -675,6 +674,7 @@ private: ...@@ -675,6 +674,7 @@ private:
} }
new_block = Downcast<Block>(Substitute( new_block = Downcast<Block>(Substitute(
new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
if (pipeline_info_[block].async) { if (pipeline_info_[block].async) {
auto &local_state = async_states_local[stage]; auto &local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index; local_state.producer_head = normalized_access_index;
...@@ -951,6 +951,12 @@ private: ...@@ -951,6 +951,12 @@ private:
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op)); Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
Array<Array<BufferRegion>> access =
GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
BlockNode *n = block.CopyOnWrite();
n->reads = access[0];
n->writes = access[1];
for (const auto &buffer : op->alloc_buffers) { for (const auto &buffer : op->alloc_buffers) {
buffer_data_to_buffer_.erase(buffer->data); buffer_data_to_buffer_.erase(buffer->data);
} }
......
...@@ -303,26 +303,27 @@ private: ...@@ -303,26 +303,27 @@ private:
} else if (access_ptr_call->op.same_as(builtin::address_of())) { } else if (access_ptr_call->op.same_as(builtin::address_of())) {
BufferLoad load = Downcast<BufferLoad>(access_ptr_call->args[0]); BufferLoad load = Downcast<BufferLoad>(access_ptr_call->args[0]);
Array<PrimExpr> indices = load->indices; Array<PrimExpr> indices = load->indices;
Array<PrimExpr> shape = load->buffer->shape; Array<PrimExpr> old_shape = load->buffer->shape;
CHECK_EQ(indices.size(), shape.size()) CHECK_EQ(indices.size(), old_shape.size())
<< "Indices size and shape size must match for general N-dimensional " << "Indices size and shape size must match for general N-dimensional "
"buffer " "buffer "
<< "but got indices size: " << indices.size() << "but got indices size: " << indices.size()
<< " and shape size: " << shape.size(); << " and shape size: " << old_shape.size();
PrimExpr elem_offset = 0; PrimExpr elem_offset = 0;
PrimExpr stride = 1; PrimExpr stride = 1;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) { for (int i = static_cast<int>(old_shape.size()) - 1; i >= 0; --i) {
elem_offset += indices[i] * stride; elem_offset += indices[i] * stride;
stride *= shape[i]; stride *= old_shape[i];
} }
PrimExpr smem_offset = PrimExpr smem_offset =
elem_offset + (offset.defined() ? offset.value() : 0); elem_offset + (offset.defined() ? offset.value() : 0);
auto new_buffer = buffer_remap_[load->buffer]; auto new_buffer = buffer_remap_[load->buffer];
auto new_shape = new_buffer->shape;
auto buffer_map_iter = auto buffer_map_iter =
buffer_map_.find(Downcast<Var>(load->buffer->data)); buffer_map_.find(Downcast<Var>(load->buffer->data));
...@@ -337,26 +338,27 @@ private: ...@@ -337,26 +338,27 @@ private:
Array<PrimExpr> multi_dim_indices; Array<PrimExpr> multi_dim_indices;
PrimExpr remaining_offset = smem_offset; PrimExpr remaining_offset = smem_offset;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) { for (int i = static_cast<int>(old_shape.size()) - 1; i >= 0; --i) {
multi_dim_indices.insert(multi_dim_indices.begin(), multi_dim_indices.insert(multi_dim_indices.begin(),
floormod(remaining_offset, shape[i])); floormod(remaining_offset, old_shape[i]));
remaining_offset = floordiv(remaining_offset, shape[i]); remaining_offset = floordiv(remaining_offset, old_shape[i]);
} }
auto forward_indices = auto forward_indices =
layout_map_[load->buffer]->Forward(multi_dim_indices); layout_map_[load->buffer]->Forward(multi_dim_indices);
PrimExpr new_offset = 0; PrimExpr new_offset = 0;
PrimExpr stride_offset = 1; PrimExpr stride_offset = 1;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) { for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
new_offset += forward_indices[i] * stride_offset; new_offset += forward_indices[i] * stride_offset;
stride_offset *= shape[i]; stride_offset *= new_shape[i];
} }
new_offset = analyzer_->Simplify(new_offset); new_offset = analyzer_->Simplify(new_offset);
Array<PrimExpr> new_indices; Array<PrimExpr> new_indices;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) { for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
new_indices.insert(new_indices.begin(), floormod(new_offset, shape[i])); new_indices.insert(new_indices.begin(),
new_offset = floordiv(new_offset, shape[i]); floormod(new_offset, new_shape[i]));
new_offset = floordiv(new_offset, new_shape[i]);
} }
auto new_access_ptr = access_ptr_call.CopyOnWrite(); auto new_access_ptr = access_ptr_call.CopyOnWrite();
...@@ -397,7 +399,6 @@ private: ...@@ -397,7 +399,6 @@ private:
LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr; LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr;
} }
BufferLoad load = Downcast<BufferLoad>(address_of_call->args[0]); BufferLoad load = Downcast<BufferLoad>(address_of_call->args[0]);
if (buffer_remap_.count(load->buffer)) { if (buffer_remap_.count(load->buffer)) {
auto new_access_ptr = auto new_access_ptr =
HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype); HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
...@@ -494,9 +495,7 @@ private: ...@@ -494,9 +495,7 @@ private:
* visitor processing. * visitor processing.
*/ */
Stmt VisitStmt_(const EvaluateNode *op) final { Stmt VisitStmt_(const EvaluateNode *op) final {
// LOG(INFO) << "evaluate node: " << op->value;
const CallNode *call = op->value.as<CallNode>(); const CallNode *call = op->value.as<CallNode>();
// LOG(INFO) << "call: " << call->op;
// Do not analysis the call node to the global function. // Do not analysis the call node to the global function.
if (call && call->op.as<GlobalVarNode>()) if (call && call->op.as<GlobalVarNode>())
return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op)); return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));
......
...@@ -44,13 +44,14 @@ def matmul( ...@@ -44,13 +44,14 @@ def matmul(
T.copy(B[bx * block_N, k * block_K], B_shared) T.copy(B[bx * block_N, k * block_K], B_shared)
else: else:
T.copy(B[k * block_K, bx * block_N], B_shared) T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
# T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N]) T.copy(C_local, C[by * block_M, bx * block_N])
return main return main
def run_gemm( def run_gemm_ss(
M, M,
N, N,
K, K,
...@@ -88,7 +89,8 @@ def run_gemm( ...@@ -88,7 +89,8 @@ def run_gemm(
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}) })
profiler = kernel.get_profiler()
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B): def ref_program(A, B):
import torch import torch
...@@ -104,11 +106,30 @@ def run_gemm( ...@@ -104,11 +106,30 @@ def run_gemm(
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm(): def test_gemm_ss():
# More test case can be found in kernel/test_tilelang_kernel_gemm.py # More test case can be found in kernel/test_tilelang_kernel_gemm.py
# GEMM tests for float16 # GEMM tests for float16
run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2)
2) # f16f16f16_nn run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2)
run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2)
# n8 test
run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
# int8 test
run_gemm_ss(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
# float8 tests
run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
# tfloat32 test
run_gemm_ss(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
def matmul_rs( def matmul_rs(
...@@ -146,18 +167,20 @@ def matmul_rs( ...@@ -146,18 +167,20 @@ def matmul_rs(
A_frag = T.alloc_fragment(A_frag_shape, in_dtype) A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local) T.clear(C_local)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A: if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared) T.copy(A[k * block_K, by * block_M], A_shared)
T.copy(A_shared, A_frag)
else: else:
T.copy(A[by * block_M, k * block_K], A_shared) T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(A_shared, A_frag)
if trans_B: if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared) T.copy(B[bx * block_N, k * block_K], B_shared)
else: else:
T.copy(B[k * block_K, bx * block_N], B_shared) T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_frag, B_shared, C_local, trans_A, trans_B) T.copy(A_shared, A_frag)
T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N]) T.copy(C_local, C[by * block_M, bx * block_N])
return main return main
...@@ -201,7 +224,7 @@ def run_gemm_rs( ...@@ -201,7 +224,7 @@ def run_gemm_rs(
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}) })
profiler = kernel.get_profiler() profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B): def ref_program(A, B):
import torch import torch
...@@ -221,6 +244,299 @@ def test_gemm_rs(): ...@@ -221,6 +244,299 @@ def test_gemm_rs():
# GEMM tests for float16 # GEMM tests for float16
run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2) run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
# n8 tests
run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
# int8 tests
run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
# float8 tests
run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
# float32 tests
run_gemm_rs(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
def matmul_sr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(B_shared, B_frag)
T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_sr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_sr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_sr():
# GEMM tests for float16
run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
# n8 tests
run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
# int8 tests
run_gemm_sr(128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
# float8 tests
run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
# float32 tests
run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
def matmul_rr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_frag)
T.copy(B_shared, B_frag)
T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_rr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_rr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_rr():
# GEMM tests for float16
run_gemm_rr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
# n8 tests
run_gemm_rr(128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2)
run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2)
# int8 tests
run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
# float8 tests
run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
# float32 tests
run_gemm_rr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -7,7 +7,7 @@ import tilelang.testing ...@@ -7,7 +7,7 @@ import tilelang.testing
def _check(original, transformed): def _check(original, transformed):
func = original func = original
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.FrontendLegalize()(mod) mod = tl.transform.LetInline()(mod)
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True) True)
......
...@@ -106,3 +106,5 @@ from .version import __version__ # noqa: F401 ...@@ -106,3 +106,5 @@ from .version import __version__ # noqa: F401
from .math import * # noqa: F403 from .math import * # noqa: F403
from . import ir # noqa: F401 from . import ir # noqa: F401
from . import tileop # noqa: F401
...@@ -85,8 +85,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: ...@@ -85,8 +85,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
""" """
mod = tir.transform.BindTarget(target)(mod) mod = tir.transform.BindTarget(target)(mod)
# Legalize the frontend IR to make it compatible with TVM # Inline let expressions and statements
mod = tilelang.transform.FrontendLegalize()(mod) mod = tilelang.transform.LetInline()(mod)
# Inject assumes to speedup tvm prover # Inject assumes to speedup tvm prover
mod = tilelang.transform.InjectAssumes()(mod) mod = tilelang.transform.InjectAssumes()(mod)
# Simplify the IR expressions # Simplify the IR expressions
......
...@@ -3,27 +3,27 @@ from tvm import arith, DataType ...@@ -3,27 +3,27 @@ from tvm import arith, DataType
import tilelang.language as T import tilelang.language as T
def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): def ldmatrix_32x4_to_shared_16x8_layout_a(thread_id, local_id):
row = thread_id % 16 row = thread_id % 16
col = 8 * (thread_id // 16) + local_id % 8 col = (thread_id // 16) * 4 + local_id % 4
return row, col return row, col
def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id): def ldmatrix_32x4_to_shared_16x8_layout_b(thread_id, local_id):
row = 8 * (thread_id // 16) + (thread_id % 8) row = (thread_id // 16) * 8 + (thread_id % 8)
col = 8 * ((thread_id % 16) // 8) + local_id % 8 col = ((thread_id % 16) // 8) * 4 + local_id % 4
return row, col return row, col
def ldmatrix_16x32_to_shared_16x32_layout_a(thread_id, local_id): def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
row = thread_id % 16 row = thread_id % 16
col = 16 * (thread_id // 16) + local_id % 16 col = 8 * (thread_id // 16) + local_id % 8
return row, col return row, col
def ldmatrix_16x32_to_shared_16x32_layout_b(thread_id, local_id): def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
row = 8 * (thread_id // 16) + (thread_id % 8) row = 8 * (thread_id // 16) + (thread_id % 8)
col = 16 * ((thread_id % 16) // 8) + local_id % 16 col = 8 * ((thread_id % 16) // 8) + local_id % 8
return row, col return row, col
...@@ -47,28 +47,78 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): ...@@ -47,28 +47,78 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
# sr represents spatial + reduction layout # sr represents spatial + reduction layout
# the first axis is spatial while the second axis is reduction # the first axis is spatial while the second axis is reduction
def shared_16x16_to_mma_32x8_layout_sr(i, j): # mma.sync matrix A layout, if wanna trans, please apply map_indices
def shared_16x8_to_mma_a_32x4_layout(i, j):
thread_id = 4 * (i % 8) + (j % 4)
return thread_id, 2 * (j // 4) + (i // 8)
def shared_16x8_to_mma_a_32x4_layout_trans(i, j):
return shared_16x8_to_mma_a_32x4_layout(j, i)
# mma.sync matrix B layout, if wanna trans, please apply map_indices
def shared_16x8_to_mma_b_32x4_layout(i, j):
thread_id = 4 * (i % 8) + (j % 4)
return thread_id, 2 * (i // 8) + (j // 4)
def shared_16x8_to_mma_b_32x4_layout_trans(i, j):
return shared_16x8_to_mma_b_32x4_layout(j, i)
shared_16x8_to_mma_32x4_layout_sr_a = shared_16x8_to_mma_a_32x4_layout
shared_16x8_to_mma_32x4_layout_sr_b = shared_16x8_to_mma_b_32x4_layout
shared_16x8_to_mma_32x4_layout_rs_a = shared_16x8_to_mma_a_32x4_layout_trans
shared_16x8_to_mma_32x4_layout_rs_b = shared_16x8_to_mma_b_32x4_layout_trans
def shared_16x16_to_mma_a_32x8_layout(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2 thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2) return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
def shared_16x16_to_mma_32x8_layout_rs(i, j): def shared_16x16_to_mma_a_32x8_layout_trans(i, j):
thread_id = 4 * (j % 8) + (i % 8) // 2 return shared_16x16_to_mma_a_32x8_layout(j, i)
return thread_id, 4 * (i // 8) + (j // 8) * 2 + (i % 2)
def shared_16x16_to_mma_b_32x8_layout(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 4 * (i // 8) + (j // 8) * 2 + (j % 2)
def shared_16x16_to_mma_b_32x8_layout_trans(i, j):
return shared_16x16_to_mma_b_32x8_layout(j, i)
shared_16x16_to_mma_32x8_layout = shared_16x16_to_mma_32x8_layout_sr shared_16x16_to_mma_32x8_layout_sr_a = shared_16x16_to_mma_a_32x8_layout
shared_16x16_to_mma_32x8_layout_trans = shared_16x16_to_mma_32x8_layout_rs shared_16x16_to_mma_32x8_layout_sr_b = shared_16x16_to_mma_b_32x8_layout
shared_16x16_to_mma_32x8_layout_rs_a = shared_16x16_to_mma_a_32x8_layout_trans
shared_16x16_to_mma_32x8_layout_rs_b = shared_16x16_to_mma_b_32x8_layout_trans
def shared_16x32_to_mma_32x16_layout(i, j): def shared_16x32_to_mma_a_32x16_layout(i, j):
thread_id = 4 * (i % 8) + (j % 16) // 4 thread_id = 4 * (i % 8) + (j % 16) // 4
return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4 return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4
def shared_32x16_to_mma_32x16_layout(i, j): def shared_32x16_to_mma_a_32x16_layout_trans(i, j):
thread_id = (i % 16) // 4 + 4 * (j % 8) return shared_16x32_to_mma_a_32x16_layout(j, i)
return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
def shared_16x32_to_mma_b_32x16_layout(i, j):
thread_id = 4 * (i % 8) + (j % 16) // 4
return thread_id, 8 * (i // 8) + (j // 16) * 4 + j % 4
def shared_32x16_to_mma_b_32x16_layout_trans(i, j):
return shared_16x32_to_mma_b_32x16_layout(j, i)
shared_16x32_to_mma_32x16_layout_sr_a = shared_16x32_to_mma_a_32x16_layout
shared_16x32_to_mma_32x16_layout_sr_b = shared_16x32_to_mma_b_32x16_layout
shared_16x32_to_mma_32x16_layout_rs_a = shared_32x16_to_mma_a_32x16_layout_trans
shared_16x32_to_mma_32x16_layout_rs_b = shared_32x16_to_mma_b_32x16_layout_trans
def mma_32x8_to_shared_16x16_layout(thread_id, local_id): def mma_32x8_to_shared_16x16_layout(thread_id, local_id):
...@@ -77,6 +127,30 @@ def mma_32x8_to_shared_16x16_layout(thread_id, local_id): ...@@ -77,6 +127,30 @@ def mma_32x8_to_shared_16x16_layout(thread_id, local_id):
return row, col return row, col
def mma_load_a_32x4_to_shared_16x8_layout(thread_id, local_id):
row = 8 * (local_id % 2) + (thread_id // 4)
col = 4 * (local_id // 2) + (thread_id % 4)
return row, col
def mma_load_b_32x4_to_shared_16x8_layout(thread_id, local_id):
row = 8 * (local_id // 2) + (thread_id // 4)
col = 4 * (local_id % 2) + (thread_id % 4)
return row, col
def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id):
row = 8 * (local_id % 8 // 4) + (thread_id // 4)
col = 16 * (local_id // 8) + (thread_id % 4) * 4 + (local_id % 4)
return row, col
def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
row = 8 * (local_id // 8) + (thread_id // 4)
col = 16 * (local_id % 8 // 4) + (thread_id % 4) * 4 + (local_id % 4)
return row, col
def shared_16x16_to_mma_32x8_smoothlayout(i, j): def shared_16x16_to_mma_32x8_smoothlayout(i, j):
return (i * 2 + j // 8, j % 8) return (i * 2 + j // 8, j % 8)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment