"examples/vscode:/vscode.git/clone" did not exist on "a41e4c506bea0179ac6e556620c7ed45cc4c5f29"
Commit 6ad73f6f authored by Lei Wang, committed by LeiWang1999

[Refactor] Support auto index bitwidth casting (#517)

* [Refactor] Enhance GEMM Warp Partitioning Logic and Introduce Buffer Remapping (#516)

* Improved the warp-partitioning logic in `Gemm::ComputeWarpPartition` to better accommodate the GEMM warp policies FullRow, FullCol, and Square, choosing the partition from the matrix dimensions (see the sketch after this list).
* Introduced a new `RemapBufferRewriter` class to update buffer references and padding annotations during statement transformations, improving memory-access safety and clarity.
* Updated the `OptimizeForTarget` function to include a new step for configuring the index bitwidth, improving the overall optimization pipeline.
* Refactored existing code to use named constants for warp sizes, improving maintainability and readability.
* Added checks to ensure correct warp allocation and padding-map handling, improving the robustness of the memory-management strategies.
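The kSquare balance search is easiest to see outside of TVM. Below is a minimal Python sketch of the same heuristic on the non-WGMMA path, using the new per-warp constants (16 rows, 8 columns); `square_partition` is illustrative only, not a library function:

```python
# Illustrative re-implementation of the kSquare warp-partition search in
# Gemm::ComputeWarpPartition (non-WGMMA path); not part of TileLang's API.
def square_partition(num_warps: int, M: int, N: int,
                     m_per_warp: int = 16, n_per_warp: int = 8):
    ideal_ratio = M / N if N > 0 else 1.0
    best, best_balance = None, float("inf")
    for m in range(1, num_warps + 1):
        if num_warps % m or m > M // m_per_warp:
            continue  # m must divide num_warps; each warp needs 16 rows
        n = num_warps // m
        if n > N // n_per_warp:
            continue  # each warp needs at least 8 columns
        # How balanced is the per-warp workload relative to the M/N ratio?
        balance = abs((M / (m * m_per_warp)) / (N / (n * n_per_warp))
                      - ideal_ratio)
        if balance < best_balance:
            best, best_balance = (m, n), balance
    return best

print(square_partition(8, 128, 128))  # -> (2, 4): 4 units of work per warp
                                      #    in each dimension, matching M/N = 1
```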

* [Refactor] Update ConfigIndexBitwidthRewriter to Support Auto-Check Feature

* Modified the constructor of `ConfigIndexBitwidthRewriter` to take an `auto_check` parameter, allowing dynamic bitwidth adjustment based on input conditions.
* Enhanced the `VisitExpr_` methods to apply the new auto-check logic, upgrading integer types to 64 bits when necessary, or to the specified index bitwidth otherwise (sketched after this list).
* Updated the `ConfigIndexBitwidth` pass to pick the index bitwidth based on whether a configuration is present, improving flexibility across scenarios.
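In auto-check mode the widening decision hinges on whether an expression's analyzed value range still fits the narrow type. A toy model of that rule; `target_bits` and its arguments are illustrative, not the pass's real interface:

```python
# Toy model of the auto-check rule: in auto-check mode an integer
# expression is upgraded to 64 bits only when range analysis flagged it
# as overflow-prone; in fixed mode every narrower integer is cast to the
# configured index bitwidth.
def target_bits(cur_bits: int, index_bits: int, auto_check: bool,
                overflow_prone: bool) -> int:
    if auto_check:
        return 64 if (overflow_prone and cur_bits < 64) else cur_bits
    return index_bits if cur_bits < index_bits else cur_bits

assert target_bits(32, 64, auto_check=True, overflow_prone=True) == 64
assert target_bits(32, 64, auto_check=True, overflow_prone=False) == 32
assert target_bits(32, 64, auto_check=False, overflow_prone=False) == 64
```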

* Add dynamic matrix multiplication example and corresponding test

* Introduced `example_dynamic.py` to demonstrate dynamic matrix multiplication with TileLang and PyTorch, including a main function for execution and performance profiling.
* Added `test_example_dynamic.py` to validate the dynamic matrix multiplication example.
* The example includes detailed parameter configurations and checks against PyTorch's implementation for correctness (see the sketch after this list).
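The validation pattern of such a test is standard PyTorch practice; a minimal sketch, where `compiled_matmul` stands in for the kernel the example actually builds with TileLang:

```python
# Sketch of the correctness-check pattern: run the kernel under test on
# random inputs and compare against a torch.matmul reference.
import torch

def check_dynamic_matmul(compiled_matmul, m: int, n: int, k: int) -> None:
    a = torch.randn(m, k, dtype=torch.float16, device="cuda")
    b = torch.randn(k, n, dtype=torch.float16, device="cuda")
    ref = (a.float() @ b.float()).half()  # fp32 reference accumulation
    out = compiled_matmul(a, b)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)

# Dynamic shapes: one compiled kernel should serve many m values, e.g.
# for m in (127, 512, 1000): check_dynamic_matmul(kernel, m, 1024, 1024)
```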

* lint fix

* Add get_num_sms function to retrieve the number of streaming multiprocessors on the CUDA device

* Implemented the `get_num_sms` function in `cuda_driver.py` to return the count of streaming multiprocessors for a specified CUDA device.
* Updated the `__init__.py` file to include the new function in the module exports.

* lint fix
parent 0d1eab57
@@ -61,20 +61,79 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
 std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
                                                bool maybe_hopper_wgmma) const {
   int m_warp = 1, n_warp = 1;
+  constexpr int kMPerWarp = 16; // Rows processed by a single warp
+  constexpr int kNPerWarp = 8;  // Columns processed by a single warp
   bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
                      (this->M >= 64) && (num_warps % 4 == 0);
   if (allow_wgmma) {
-    ICHECK(num_warps % 4 == 0) << "Use Warp Group MMA requires 128*N threads.";
-    if (this->policy == GemmWarpPolicy::kFullRow ||
-        this->policy == GemmWarpPolicy::kSquare) {
-      m_warp = num_warps;
-      n_warp = 1;
+    ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";
+
+    constexpr int kGroup = 4; // Number of warps in a warp-group
+
+    m_warp = kGroup;             // Initially, only one warp-group on M dimension
+    n_warp = num_warps / m_warp; // Rest all on N dimension
+
+    if (this->policy == GemmWarpPolicy::kFullRow) {
+      // Try to put as many warp-groups as possible on the M dimension
+      // (decreasing multiples of 4, ensuring divisibility by M)
+      for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
+        if (this->M % (cand * kMPerWarp) == 0) {
+          m_warp = cand;
+          n_warp = num_warps / m_warp;
+          break;
+        }
+      }
     } else if (this->policy == GemmWarpPolicy::kFullCol) {
-      m_warp = 1;
-      n_warp = num_warps;
+      // Try to use all warps on the N dimension; if N is not divisible,
+      // split the excess groups to M
+      int cand_n = n_warp;                       // Initially assume all on N
+      if (this->N % (cand_n * kNPerWarp) != 0) { // N-direction division fails
+        int max_n = this->N / kNPerWarp;
+        // Search downward from the largest feasible n_warp, ensuring
+        // num_warps / n_warp stays a multiple of 4
+        for (int n = std::min(cand_n, max_n); n >= 1; --n) {
+          if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
+            n_warp = n;
+            m_warp = num_warps / n_warp;
+            break;
+          }
+        }
+      }
+    } else if (this->policy == GemmWarpPolicy::kSquare) {
+      // Exhaustive search, but m must be a multiple of 4
+      int max_m = this->M / kMPerWarp;
+      int max_n = this->N / kNPerWarp;
+      float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;
+      float best_score = std::numeric_limits<float>::max();
+      int best_m = kGroup, best_n = n_warp;
+      for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
+        if (num_warps % m)
+          continue;
+        int n = num_warps / m;
+        if (n > max_n)
+          continue;
+        float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
+        float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
+        float score = std::abs(m_per_warp / n_per_warp - ideal);
+        if (score < best_score) {
+          best_score = score;
+          best_m = m;
+          best_n = n;
+        }
+      }
+      m_warp = best_m;
+      n_warp = best_n;
     } else {
       ICHECK(0) << "Unknown GemmWarpPolicy";
     }
+    ICHECK(m_warp * n_warp == num_warps)
+        << "m_warp * n_warp must equal num_warps";
     return {m_warp, n_warp};
   }
@@ -85,9 +144,9 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
     // If M cannot be evenly divided by m_warp*16, try to split remaining warps
     // to N
-    if (this->M % (m_warp * 16) != 0) {
+    if (this->M % (m_warp * kMPerWarp) != 0) {
       // Calculate how many warps we can use for M
-      int max_m_warps = this->M / 16;
+      int max_m_warps = this->M / kMPerWarp;
       m_warp = max_m_warps;
       // Use remaining warps for N
       n_warp = num_warps / m_warp;
@@ -101,9 +160,9 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
     // If N cannot be evenly divided by n_warp*8, try to split remaining warps
     // to M
-    if (this->N % (n_warp * 8) != 0) {
+    if (this->N % (n_warp * kNPerWarp) != 0) {
       // Calculate how many warps we can use for N
-      int max_n_warps = this->N / 8;
+      int max_n_warps = this->N / kNPerWarp;
       n_warp = max_n_warps;
       // Use remaining warps for M
       m_warp = num_warps / n_warp;
@@ -112,8 +171,10 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
     }
   } else if (this->policy == GemmWarpPolicy::kSquare) {
     // First calculate the maximum possible warps for each dimension
-    int max_m_warps = this->M / 16; // Each warp needs at least 16 elements in M
-    int max_n_warps = this->N / 8;  // Each warp needs at least 8 elements in N
+    int max_m_warps =
+        this->M / kMPerWarp; // Each warp needs at least 16 elements in M
+    int max_n_warps =
+        this->N / kNPerWarp; // Each warp needs at least 8 elements in N
     // Calculate the ideal ratio of M/N warps based on the matrix dimensions
     float ideal_ratio = 1.0f;
@@ -139,8 +200,8 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
         continue;
       // Calculate how balanced this partition is
-      float m_per_warp = static_cast<float>(this->M) / (m * 16);
-      float n_per_warp = static_cast<float>(this->N) / (n * 8);
+      float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
+      float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
       float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);
       if (balance < best_balance) {
...
@@ -21,14 +21,14 @@
  * \file flatten_buffer.cc
  */
-#include "arith/ir_mutator_with_analyzer.h"
-#include "tir/transforms/ir_utils.h"
 #include <tvm/arith/iter_affine_map.h>
 #include <tvm/tir/analysis.h>
+#include <tvm/tir/data_type_rewriter.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
+#include "arith/ir_mutator_with_analyzer.h"
+#include "tir/transforms/ir_utils.h"

 namespace tvm {
 namespace tl {
@@ -59,6 +59,43 @@ private:
   using IRMutatorWithAnalyzer::VisitStmt;
   using IRMutatorWithAnalyzer::VisitStmt_;

+  class Int64Promoter : public tir::IndexDataTypeRewriter {
+  public:
+    using Parent = IndexDataTypeRewriter;
+
+    PrimExpr VisitExpr_(const VarNode *op) final {
+      if (op->dtype.is_int() && op->dtype.bits() < 64) {
+        return cast(DataType::Int(64), GetRef<Var>(op));
+      }
+      return GetRef<PrimExpr>(op);
+    }
+
+    PrimExpr VisitExpr_(const IntImmNode *op) final {
+      if (op->dtype.is_int() && op->dtype.bits() < 64) {
+        return IntImm(DataType::Int(64), op->value);
+      }
+      return GetRef<PrimExpr>(op);
+    }
+
+    PrimExpr VisitExpr_(const CastNode *op) final {
+      if (op->dtype.is_int() && op->dtype.bits() < 64) {
+        return cast(DataType::Int(64), op->value);
+      }
+      return GetRef<PrimExpr>(op);
+    }
+
+    Stmt VisitStmt_(const BufferStoreNode *op) final {
+      // Force indices to be int64
+      auto node = Downcast<BufferStore>(Parent::VisitStmt_(op));
+      return std::move(node);
+    }
+
+    PrimExpr VisitExpr_(const BufferLoadNode *op) final {
+      auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op));
+      return std::move(node);
+    }
+  };
+
   explicit BufferFlattener(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {}

   Stmt VisitStmt_(const BlockNode *op) final {
@@ -239,7 +276,28 @@ private:
   Array<PrimExpr> GetSimplifiedElemOffset(const Buffer &buffer,
                                           const Array<PrimExpr> &indices) {
     auto flattened_indices = buffer->ElemOffset(indices);
-    return this->IterMapSimplifyWithContext(flattened_indices, false);
+    Array<PrimExpr> safe_indices;
+    for (auto index : flattened_indices) {
+      auto int_bound = analyzer_->const_int_bound(index);
+      DataType dtype = index->dtype;
+      if (dtype.is_int() && dtype.bits() < 64) {
+        int64_t max_value = int_bound->max_value + 1;
+        int64_t min_value = int_bound->min_value;
+        const int64_t type_max = (1LL << (dtype.bits() - 1));
+        const int64_t type_min = -(1LL << (dtype.bits() - 1));
+        if (max_value >= type_max || min_value < type_min) {
+          Int64Promoter promoter;
+          for (auto &index : flattened_indices) {
+            safe_indices.push_back(promoter(index));
+          }
+        } else {
+          safe_indices.push_back(index);
+        }
+      } else {
+        safe_indices.push_back(index);
+      }
+    }
+    return this->IterMapSimplifyWithContext(safe_indices, false);
   }

   template <typename Node> Node VisitBufferAccess(Node node) {
...
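A concrete number makes the guard above clearer: the flattened offset is promoted to int64 exactly when its proven bounds escape the signed range of its current type, which already happens for buffers a bit over 2^31 elements. A Python rendering of the same bound test (illustrative only):

```python
# Same bound test as in GetSimplifiedElemOffset: promote to int64 when
# the proven [min, max] range of a flattened offset escapes its type.
def offset_overflows(bits: int, min_value: int, max_value: int) -> bool:
    type_max = 1 << (bits - 1)    # 2**31 for int32
    type_min = -(1 << (bits - 1))
    return (max_value + 1) >= type_max or min_value < type_min

# A 50000 x 50000 buffer flattens to offsets up to ~2.5e9 > 2**31 - 1:
print(offset_overflows(32, 0, 50000 * 50000 - 1))  # True  -> use int64
print(offset_overflows(32, 0, 4096 * 4096 - 1))    # False -> int32 is safe
```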
@@ -315,7 +315,8 @@ private:
             .value();
     for (const auto &[var, padding] : map) {
       ICHECK(buffer_data_to_buffer_.count(var))
-          << "buffer " << var << " is not found in the block";
+          << "buffer " << var << " is not found in the block "
+          << buffer_data_to_buffer_;
       auto buffer = buffer_data_to_buffer_[var];
       annotated_padding_map_.Set(buffer, padding);
     }
...
@@ -70,6 +70,91 @@ static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout,
                 buffer->buffer_type);
 }

+/*!
+ * \brief A class that rewrites buffer references in a statement based on a
+ * given buffer remapping.
+ *
+ * This class is used to update buffer references in a statement after buffer
+ * transformations have been applied. It specifically handles the remapping of
+ * padding annotations.
+ */
+class RemapBufferRewriter : public arith::IRMutatorWithAnalyzer {
+public:
+  /*!
+   * \brief Substitute buffer references in a statement based on a given buffer
+   * remapping.
+   * \param stmt The statement to rewrite.
+   * \param buffer_remap A map from old buffers to new buffers.
+   * \return The rewritten statement.
+   */
+  static Stmt Substitute(Stmt stmt, Map<Buffer, Buffer> buffer_remap) {
+    arith::Analyzer analyzer;
+    RemapBufferRewriter substituter(&analyzer);
+    substituter.buffer_remap_ = std::move(buffer_remap);
+    return substituter.VisitStmt(stmt);
+  }
+
+private:
+  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;
+
+  Stmt VisitStmt_(const BlockNode *op) final {
+    if (op->annotations.count(attr::kPaddingMap)) {
+      return RewritePaddingMap(op);
+    }
+    return IRMutatorWithAnalyzer::VisitStmt_(op);
+  }
+
+  /*!
+   * \brief Rewrite the padding map annotation of a block.
+   * \param op The block node to rewrite.
+   * \return The rewritten block.
+   */
+  Stmt RewritePaddingMap(const BlockNode *op) {
+    auto padding_map =
+        op->annotations.Get(attr::kPaddingMap).as<Map<Var, PrimExpr>>().value();
+    Map<Var, Var> var_remap = CreateVarRemap();
+    Map<Var, PrimExpr> new_padding_map =
+        RemapPaddingMap(padding_map, var_remap);
+    auto block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
+    auto block_ptr = block.CopyOnWrite();
+    block_ptr->annotations.Set(attr::kPaddingMap, new_padding_map);
+    return block;
+  }
+
+  /*!
+   * \brief Create a mapping from old variables to new variables based on the
+   * buffer remapping.
+   * \return A map from old variables to new variables.
+   */
+  Map<Var, Var> CreateVarRemap() const {
+    Map<Var, Var> var_remap;
+    for (const auto &[buffer, buffer_remap] : buffer_remap_) {
+      var_remap.Set(buffer->data, buffer_remap->data);
+    }
+    return var_remap;
+  }
+
+  /*!
+   * \brief Remap the padding map using the variable remapping.
+   * \param padding_map The original padding map.
+   * \param var_remap The variable remapping.
+   * \return The remapped padding map.
+   */
+  Map<Var, PrimExpr> RemapPaddingMap(const Map<Var, PrimExpr> &padding_map,
+                                     const Map<Var, Var> &var_remap) const {
+    Map<Var, PrimExpr> new_padding_map;
+    for (const auto &[var, padding] : padding_map) {
+      if (var_remap.count(var)) {
+        new_padding_map.Set(var_remap.at(var), padding);
+      } else {
+        new_padding_map.Set(var, padding);
+      }
+    }
+    return new_padding_map;
+  }
+
+  Map<Buffer, Buffer> buffer_remap_;
+};
+
 class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
 public:
   static PrimFunc Substitute(PrimFunc f) {
@@ -85,6 +170,8 @@ public:
     substituter.target_ = target.value();
     PrimFuncNode *fptr = f.CopyOnWrite();
     fptr->body = substituter.VisitStmt(f->body);
+    fptr->body =
+        RemapBufferRewriter::Substitute(fptr->body, substituter.buffer_remap_);
     return f;
   }
...
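The padding-map rewrite itself is plain key substitution over the annotation; a Python analogue, with illustrative strings standing in for TVM Vars:

```python
# Python analogue of RemapPaddingMap: when a buffer is replaced, its
# padding annotation must be re-keyed to the new buffer's data var.
def remap_padding_map(padding_map: dict, var_remap: dict) -> dict:
    return {var_remap.get(v, v): pad for v, pad in padding_map.items()}

print(remap_padding_map({"A.data": 0, "B.data": 1},
                        {"A.data": "A_layout.data"}))
# {'A_layout.data': 0, 'B.data': 1}
```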
@@ -4,4 +4,5 @@ from .cuda_driver import (
     get_shared_memory_per_block,  # noqa: F401
     get_device_attribute,  # noqa: F401
     get_max_dynamic_shared_size_bytes,  # noqa: F401
+    get_num_sms,  # noqa: F401
 )
@@ -162,3 +162,23 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
         raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb")
     else:
         raise RuntimeError("Failed to get device properties.")
+
+
+def get_num_sms(device_id: int = 0) -> int:
+    """
+    Get the number of streaming multiprocessors (SMs) on the CUDA device.
+
+    Args:
+        device_id (int, optional): The CUDA device ID. Defaults to 0.
+
+    Returns:
+        int: The number of SMs on the device.
+
+    Raises:
+        RuntimeError: If unable to get the device properties.
+    """
+    prop = get_cuda_device_properties(device_id)
+    if prop:
+        return prop.multiProcessorCount
+    else:
+        raise RuntimeError("Failed to get device properties.")
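A typical consumer of `get_num_sms` is persistent-kernel scheduling, where the grid is sized to the SM count. A usage sketch; to stay self-contained it reads the same CUDA device property through PyTorch rather than importing this module:

```python
# Usage sketch: launch one persistent block per SM. PyTorch exposes the
# same device property (multiProcessorCount) that get_num_sms returns.
import torch

def get_num_sms(device_id: int = 0) -> int:
    return torch.cuda.get_device_properties(device_id).multi_processor_count

if torch.cuda.is_available():
    grid_size = get_num_sms(0)  # e.g. 132 on an H100 SXM, 108 on an A100
```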
@@ -102,8 +102,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.InjectFenceProxy()(mod)

     mod = tir.transform.LowerOpaqueBlock()(mod)
-    mod = tilelang.transform.FlattenBuffer()(mod)
     mod = tir.transform.NarrowDataType(32)(mod)
+    mod = tilelang.transform.ConfigIndexBitwidth()(mod)
+    mod = tilelang.transform.FlattenBuffer()(mod)
     mod = tir.transform.Simplify()(mod)
     mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
@@ -131,7 +132,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tir.transform.InferFragment()(mod)
     mod = tir.transform.LowerThreadAllreduce()(mod)
     mod = tilelang.transform.LowerHopperIntrin()(mod)
-    mod = tilelang.transform.ConfigIndexBitwidth()(mod)
     mod = tilelang.transform.ThreadSync("shared")(mod)
     mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
     mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)
...
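Running ConfigIndexBitwidth before FlattenBuffer matters because overflow first materializes in the flattened offset, where per-dimension indices are multiplied by strides; deciding the bitwidth after flattening would be too late. A small NumPy illustration of the wraparound the reorder guards against:

```python
# Why index bitwidth must be settled before buffer flattening: the
# flattened offset i * stride + j is where a 32-bit index first wraps.
import numpy as np

i, j, stride = 50000, 40000, 50000                     # 50000 x 50000 buffer
off32 = np.int32(i) * np.int32(stride) + np.int32(j)   # wraps (with warning)
off64 = np.int64(i) * np.int64(stride) + np.int64(j)
print(off32, off64)  # off32 is negative; off64 == 2500040000
```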