"git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "a9031464271a961b4b23d9cbf0e5d944dc8a78bf"
Commit 623edf4c authored by Lei Wang, committed by LeiWang1999

[Enhancement] Support auto synchronization for global memory access (#519)

* [Refactor] Enhance GEMM Warp Partitioning Logic and Introduce Buffer Remapping (#516)

* Improved the warp partitioning logic in `Gemm::ComputeWarpPartition` to better accommodate the supported GEMM policies (FullRow, FullCol, and Square), selecting the partition that best fits the matrix dimensions (see the sketch after this list).
* Introduced a new `RemapBufferRewriter` class to handle buffer reference updates and padding annotations during statement transformations, enhancing memory access safety and clarity.
* Updated the `OptimizeForTarget` function to include a new step for configuring index bitwidth, improving the overall optimization process.
* Refactored existing code to utilize constants for warp sizes, enhancing maintainability and readability.
* Added checks to ensure correct warp allocation and padding map handling, improving robustness in memory management strategies.
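As a rough illustration of the three policies, here is a minimal, self-contained sketch. It is not the repository's `Gemm::ComputeWarpPartition` (which operates on TileLang's IR types), and the Square heuristic shown is an assumption:

```cpp
#include <cassert>
#include <utility>

enum class GemmWarpPolicy { kFullRow, kFullCol, kSquare };

// Split num_warps across the M and N dimensions of a block tile.
std::pair<int, int> ComputeWarpPartitionSketch(int block_m, int block_n,
                                               int num_warps,
                                               GemmWarpPolicy policy) {
  int m_warp = 1, n_warp = num_warps; // fallback: a single row of warps
  switch (policy) {
  case GemmWarpPolicy::kFullRow: // all warps tile the M dimension
    m_warp = num_warps;
    n_warp = 1;
    break;
  case GemmWarpPolicy::kFullCol: // all warps tile the N dimension
    m_warp = 1;
    n_warp = num_warps;
    break;
  case GemmWarpPolicy::kSquare:
    // Pick the factorization m * n == num_warps whose aspect ratio best
    // matches the tile: the largest m with m/n <= block_m/block_n.
    for (int m = num_warps; m >= 1; --m) {
      if (num_warps % m)
        continue;
      int n = num_warps / m;
      if (block_m * n >= block_n * m) {
        m_warp = m;
        n_warp = n;
        break;
      }
    }
    break;
  }
  // Mirrors the "correct warp allocation" check mentioned above.
  assert(m_warp * n_warp == num_warps);
  return {m_warp, n_warp};
}
```

For example, with 4 warps the Square policy yields a 2x2 grid for a 128x128 tile, but 4x1 for a 128x32 tile.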

* [Refactor] Update ConfigIndexBitwidthRewriter to Support Auto-Check Feature

* Modified the constructor of `ConfigIndexBitwidthRewriter` to include an `auto_check` parameter, allowing for dynamic bitwidth adjustments based on input conditions.
* Enhanced the `VisitExpr_` methods to apply the new auto-check logic, ensuring that integer types are upgraded to 64 bits when necessary, or to a specified index bitwidth otherwise.
* Updated the `ConfigIndexBitwidth` pass to derive the index bitwidth from whether an explicit configuration is present, improving flexibility across different scenarios (the decision rule is sketched after this list).
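The decision rule, reduced to a hypothetical stand-alone function (the real pass rewrites TIR expressions in place rather than returning a width):

```cpp
#include <cstdint>

// has_config: an explicit index bitwidth was configured by the user.
// auto_check: widen to 64-bit only when an index can overflow int32.
int ChooseIndexBitwidth(bool has_config, int configured_bitwidth,
                        bool auto_check, int64_t max_index_extent) {
  if (has_config)
    return configured_bitwidth; // an explicit configuration wins
  if (auto_check && max_index_extent > INT32_MAX)
    return 64; // extent too large for 32-bit indexing
  return 32;   // default: keep indices narrow for cheaper address math
}
```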

* Add dynamic matrix multiplication example and corresponding test

* Introduced `example_dynamic.py` to demonstrate dynamic matrix multiplication using TileLang and PyTorch, including a main function for execution and performance profiling.
* Added `test_example_dynamic.py` to validate the functionality of the dynamic matrix multiplication example.
* The example includes detailed parameter configurations and checks against PyTorch's implementation for correctness.

* lint fix

* Add get_num_sms function to retrieve the number of streaming multiprocessors on the CUDA device

* Implemented the `get_num_sms` function in `cuda_driver.py` to return the count of streaming multiprocessors for a specified CUDA device.
* Updated the `__init__.py` file to include the new function in the module exports (the underlying driver query is sketched below).
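The Python helper presumably wraps the driver-level attribute query; a C++ equivalent using the CUDA driver API (error handling elided for brevity) looks like this:

```cpp
#include <cuda.h>
#include <cstdio>

int GetNumSMs(int device_ordinal) {
  CUdevice dev;
  int num_sms = 0;
  cuInit(0); // idempotent; required before any other driver API call
  cuDeviceGet(&dev, device_ordinal);
  cuDeviceGetAttribute(&num_sms, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
                       dev);
  return num_sms;
}

int main() {
  std::printf("SMs on device 0: %d\n", GetNumSMs(0));
}
```

The SM count matters for the global barrier introduced below: a grid-wide spin barrier only terminates if every block is resident on the device, so launches are typically capped by how many blocks the SMs can hold at once.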

* lint fix

* Add global barrier state and expectation handling in CUDA code generation

* Introduced `vid_global_barrier_state_` and `vid_global_barrier_expect_` to manage global barrier synchronization in the CUDA code generator.
* Updated `Finish` method to declare the global barrier state if needed.
* Implemented handling for `EvaluateNode` to initialize the barrier expectation.
* Removed unnecessary extern declaration for the global barrier state in `PrintStorageSync` method.
* Enhanced CUDA FP8 type definitions, replacing the CuTe-backed array wrappers with explicitly aligned structs (the barrier pattern and the new FP8 layout are both sketched after this list).
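For context, here is a hand-written sketch of the pattern the generated code implements. This is not the literal codegen output: the names mirror `vid_global_barrier_state_` and `vid_global_barrier_expect_`, and the kernel body is a hypothetical two-phase producer/consumer. The pattern assumes all blocks are co-resident on the GPU:

```cpp
#include <cuda_runtime.h>

// Managed global counter, emitted into the module by Finish() when a
// "global" sync is present (initialized once, never reset).
__device__ __managed__ unsigned __tvm_global_barrier_state = 0;

__global__ void two_phase_kernel(float *data) {
  // Emitted by the tvm_global_barrier_kinit handling in VisitStmt_:
  // a per-block expectation counter in shared memory.
  __shared__ unsigned __barrier_expect;
  if (threadIdx.x == 0) {
    __barrier_expect = 0;
  }

  // ... phase 1: each block writes its slice of `data` ...

  // Emitted at a "global" storage-sync point:
  __threadfence(); // publish this block's global writes
  __syncthreads(); // make sure the whole block has arrived
  if (threadIdx.x == 0) {
    atomicAdd(&__tvm_global_barrier_state, 1);
    // Raise the expectation instead of resetting the counter, so the
    // same barrier can be crossed repeatedly.
    __barrier_expect += gridDim.x;
    volatile unsigned *pf = &__tvm_global_barrier_state;
    while (*pf < __barrier_expect) {
      // spin until all gridDim.x blocks have checked in
    }
  }
  __syncthreads();

  // ... phase 2: blocks may now read slices written by other blocks ...
}
```

The new FP8 definitions replace the CuTe-backed array-of-8/16 wrappers with explicitly aligned nested structs over the `<cuda_fp8.h>` types, which is what allows the vectorize pass to drop its "do not vectorize float8" special case (see the diff below). A minimal, assumed usage pattern:

```cpp
#include <cuda_fp8.h>

using fp8_e4_t = __nv_fp8_e4m3;
struct __align__(4) fp8_e4_4_t { // stand-in for the repo's __CUDA_ALIGN__(4)
  fp8_e4_t x, y, z, w;
};

// With 4-byte alignment, four fp8 values move as one 32-bit transaction
// instead of four single-byte loads/stores.
__global__ void copy_fp8x4(const fp8_e4_t *__restrict__ src,
                           fp8_e4_t *__restrict__ dst, int n4) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n4) {
    reinterpret_cast<fp8_e4_4_t *>(dst)[i] =
        reinterpret_cast<const fp8_e4_4_t *>(src)[i];
  }
}
```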
parent 6ad73f6f
@@ -51,6 +51,11 @@ static std::string GetFP8Type(DataType type) {
 CodeGenTileLangCUDA::CodeGenTileLangCUDA() {
   restrict_keyword_ = "__restrict__";
+  vid_global_barrier_state_ =
+      name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state);
+  vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect");
+  ICHECK_EQ(vid_global_barrier_state_,
+            runtime::symbol::tvm_global_barrier_state);
 }

 void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream &os) {
@@ -118,7 +123,13 @@ std::string CodeGenTileLangCUDA::Finish() {
   decl_stream << "#include <tl_templates/cuda/ldsm.h>\n";
   decl_stream << "#include <tl_templates/cuda/threadblock_swizzle.h>\n";
   decl_stream << "#include <tl_templates/cuda/debug.h>\n";
+  if (need_global_barrier_) {
+    decl_stream << "__device__ __managed__ unsigned "
+                << vid_global_barrier_state_ << " = 0;\n";
+  }
   decl_stream << "\n";
   return CodeGenC::Finish();
 }
...@@ -547,8 +558,6 @@ void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) { ...@@ -547,8 +558,6 @@ void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) {
} else if (sync == "global") { } else if (sync == "global") {
if (!need_global_barrier_) { if (!need_global_barrier_) {
need_global_barrier_ = true; need_global_barrier_ = true;
this->decl_stream << "extern \"C\" __device__ unsigned "
<< vid_global_barrier_state_ << ";\n";
} }
// global synchronizer // global synchronizer
std::string is_load = PrintExpr(op->args[1]); std::string is_load = PrintExpr(op->args[1]);
@@ -1349,6 +1358,24 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
   this->PrintStmt(op->body);
 }

+void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
+  if (is_const_int(op->value))
+    return;
+  const CallNode *call = op->value.as<CallNode>();
+  if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
+    PrintIndent();
+    stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
+    PrintIndent();
+    stream << "if (threadIdx.x == 0) {\n";
+    PrintIndent();
+    stream << " " << vid_global_barrier_expect_ << " = 0;\n";
+    PrintIndent();
+    stream << "}\n";
+  } else {
+    CodeGenC::VisitStmt_(op);
+  }
+}
+
 void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
   int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
   CHECK_LE(lanes, 4) << "Translate Ramp Node " << GetRef<Ramp>(op) << " with "
......
@@ -47,6 +47,7 @@ public:
   void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
   void VisitExpr_(const CallNode *op, std::ostream &os) final;
   void VisitExpr_(const CastNode *op, std::ostream &os) final;
+  void VisitStmt_(const EvaluateNode *op) final;
   void VisitStmt_(const AllocateNode *op) final;
   void VisitStmt_(const AttrStmtNode *op) final;
......
 #pragma once
-#include <cute/numeric/numeric_types.hpp>
+#include <cuda_fp8.h>

-using fp8_e4_t = cute::float_e4m3_t;
-using fp8_e4_2_t = __nv_fp8x2_e4m3;
-using fp8_e4_4_t = __nv_fp8x4_e4m3;
-struct fp8_e4_8_t {
-  fp8_e4_t data[8];
-};
-struct fp8_e4_16_t {
-  fp8_e4_t data[16];
-};
-using fp8_e5_t = cute::float_e5m2_t;
-using fp8_e5_2_t = __nv_fp8x2_e5m2;
-using fp8_e5_4_t = __nv_fp8x4_e5m2;
-struct fp8_e5_8_t {
-  fp8_e5_t data[8];
-};
-struct fp8_e5_16_t {
-  fp8_e5_t data[16];
-};
+using fp8_e4_t = __nv_fp8_e4m3;
+struct __CUDA_ALIGN__(2) fp8_e4_2_t {
+  fp8_e4_t x;
+  fp8_e4_t y;
+};
+
+struct __CUDA_ALIGN__(4) fp8_e4_4_t {
+  fp8_e4_t x;
+  fp8_e4_t y;
+  fp8_e4_t z;
+  fp8_e4_t w;
+};
+
+struct __CUDA_ALIGN__(8) fp8_e4_8_t {
+  fp8_e4_4_t x;
+  fp8_e4_4_t y;
+};
+
+struct __CUDA_ALIGN__(16) fp8_e4_16_t {
+  fp8_e4_8_t x;
+  fp8_e4_8_t y;
+};
+
+using fp8_e5_t = __nv_fp8_e5m2;
+struct __CUDA_ALIGN__(2) fp8_e5_2_t {
+  fp8_e5_t x;
+  fp8_e5_t y;
+};
+
+struct __CUDA_ALIGN__(4) fp8_e5_4_t {
+  fp8_e5_t x;
+  fp8_e5_t y;
+  fp8_e5_t z;
+  fp8_e5_t w;
+};
+
+struct __CUDA_ALIGN__(8) fp8_e5_8_t {
+  fp8_e5_4_t x;
+  fp8_e5_4_t y;
+};
+
+struct __CUDA_ALIGN__(16) fp8_e5_16_t {
+  fp8_e5_8_t x;
+  fp8_e5_8_t y;
+};
...@@ -120,9 +120,6 @@ private: ...@@ -120,9 +120,6 @@ private:
const DataType &access_type = buffer->dtype; const DataType &access_type = buffer->dtype;
// i // 2, i % 8 can also be vectorized as factor 16 // i // 2, i % 8 can also be vectorized as factor 16
int max_vector_size = vector_load_bits_max_ / access_type.bits(); int max_vector_size = vector_load_bits_max_ / access_type.bits();
if (access_type.is_e4m3_float8() or access_type.is_e5m2_float8()) {
max_vector_size = 1; // [temporarily] do not vectorize float8
}
// so we should disable this GCD optimization // so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value); max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
auto last_dim = buffer->shape.back(); auto last_dim = buffer->shape.back();
......
@@ -132,6 +132,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tir.transform.InferFragment()(mod)
     mod = tir.transform.LowerThreadAllreduce()(mod)
     mod = tilelang.transform.LowerHopperIntrin()(mod)
+    mod = tilelang.transform.ThreadSync("global")(mod)
     mod = tilelang.transform.ThreadSync("shared")(mod)
     mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
     mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)
......