"...git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "17f7394fc684206f675aa880c3ed528cb166f259"
Commit dda8ebff authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Enhancing the handling of conditional statements in the pipeline (#201)

* Optimize CMake build process with dynamic job count calculation

- Modify build_csrc function to use 90% of available CPU cores
- Ensure at least one job is used during compilation
- Improve build performance by dynamically adjusting parallel job count

* Optimize build_csrc function with multiprocessing module

- Replace os.cpu_count() with multiprocessing.cpu_count()
- Maintain existing 90% CPU utilization logic
- Improve CPU core count calculation for build process

* Add dynamic shape support with out_idx in Cython JIT kernel compilation

- Implement `run_cython_dynamic_shape_with_out_idx` function in test_tilelang_jit_gemm_cython.py
- Update Cython wrapper to handle dynamic symbolic shapes during tensor allocation
- Add support for resolving dynamic shape dimensions using input tensor references
- Enhance flexibility of JIT kernel compilation with symbolic shape handling

* Enhance error reporting for dynamic symbolic shape resolution in Cython JIT kernel

- Add detailed error message when a dynamic symbolic dimension is not found in dynamic_symbolic_map
- Improve debugging by providing context about missing symbolic dimensions
- Maintain existing dynamic shape resolution logic

* Fix Copy operation handling for scalar and multi-dimensional tensors

- Add special handling for scalar tensor copy operations
- Enhance error reporting in MakeIndices method with more detailed diagnostic information
- Improve SIMT loop generation to support zero-dimensional tensors
- Add explicit check and handling for scalar tensor scenarios

* Refactor Copy operation code formatting and improve readability

- Improve code formatting in MakeIndices and MakeSIMTLoop methods
- Add line breaks to enhance readability of complex ICHECK statements
- Simplify code structure in scalar tensor handling
- Remove unnecessary whitespace and improve code alignment

* Simplify GEMM example with direct kernel compilation

- Update copyright header to Tile-AI Corporation
- Remove Profiler import and usage
- Replace tilelang.lower() with tilelang.compile()
- Simplify kernel execution workflow
- Update kernel source retrieval method

* Enhance block sparse attention implementation

- Update `blocksparse_flashattn` to use 2 stages for improved performance.
- Change `block_mask_dtype` from `int8` to `bool` for better memory efficiency.
- Modify condition checks in the kernel to utilize boolean values.
- Introduce a new example for top-k sparse attention and a benchmark for native sparse attention.
- Add support for asynchronous copy in PTX and improve pipeline planning with condition handling.

* Refactor and clean up code formatting across multiple files

- Added whitespace for improved readability in `example_blocksparse_gemm.py`, `example_tilelang_nsa_fwd.py`, and `benchmark_nsa_fwd.py`.
- Enhanced code structure and alignment in `inject_ptx_async_copy.cc` and `pipeline_planning.cc`.
- Updated comments and documentation for clarity in `__init__.py` and `phase.py`.
- Ensured consistent formatting and style across the codebase.

* Add kernel source printing in example_tilelang_nsa_fwd.py and implement IfThenElse node replacement in inject_pipeline.cc

- Added a print statement to output the kernel source in `example_tilelang_nsa_fwd.py` for debugging purposes.
- Introduced a new function `replace_if_then_else` in `inject_pipeline.cc` to transform IfThenElse nodes while preserving attributes, enhancing the handling of conditional statements in the pipeline.

* Refactor condition handling in inject_pipeline.cc

- Change the data structure for mapping conditions to statements from a Map to an Array for improved performance and simplicity.
- Update condition comparison logic to use StructuralEqual for better accuracy.
- Enhance logging to provide detailed insights into condition changes and statement processing.
- Adjust final statement construction to utilize the new data structure, ensuring correct handling of conditions and statements.

* Improve logging and formatting in inject_pipeline.cc

- Enhance logging statements for better clarity on condition changes and statement processing.
- Adjust formatting for improved readability, including line breaks and consistent spacing.
- Ensure accurate condition comparison and handling in the pipeline logic.

* Refactor logging and clean up inject_pipeline.cc

- Remove excessive logging statements to streamline the code and improve performance.
- Simplify condition handling by eliminating unnecessary log outputs related to condition changes and statement processing.
- Maintain the core functionality while enhancing code readability and maintainability.
parent c2b9b59d
...@@ -140,7 +140,7 @@ if __name__ == "__main__": ...@@ -140,7 +140,7 @@ if __name__ == "__main__":
scale=scale, scale=scale,
) )
kernel = tilelang.compile(program, out_idx=-1) kernel = tilelang.compile(program, out_idx=-1)
print(kernel.get_kernel_source())
torch.random.manual_seed(0) torch.random.manual_seed(0)
Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
......
...@@ -80,6 +80,35 @@ struct BufferAccessInfo { ...@@ -80,6 +80,35 @@ struct BufferAccessInfo {
int use = -1; // the last using stage of the buffer int use = -1; // the last using stage of the buffer
}; };
/*!
* \brief Replace IfThenElse nodes with their then_case, preserving attribute
* nodes \param body The statement to process \param condition The condition to
* match in IfThenElse nodes \return The transformed statement
*/
Stmt replace_if_then_else(Stmt body, PrimExpr condition) {
if (const auto *if_node = body.as<IfThenElseNode>()) {
// If this is an IfThenElse with the matching condition, replace it with its
// then_case
if (if_node->condition.same_as(condition)) {
return if_node->then_case;
}
} else if (const auto *attr_node = body.as<AttrStmtNode>()) {
// For attribute nodes, preserve the attribute but process its body
AttrStmt attr_stmt = GetRef<AttrStmt>(attr_node);
attr_stmt.CopyOnWrite()->body =
replace_if_then_else(attr_node->body, condition);
return attr_stmt;
} else if (const auto *block_node = body.as<BlockNode>()) {
// For block nodes, process the body
Block block = GetRef<Block>(block_node);
block.CopyOnWrite()->body =
replace_if_then_else(block_node->body, condition);
return block;
}
// For any other node type, return it unchanged
return body;
}
/*! /*!
* \brief Rewriter for the body of the software pipeline. This pass inserts * \brief Rewriter for the body of the software pipeline. This pass inserts
* `floormod` to indices of the remapped buffer to select the version * `floormod` to indices of the remapped buffer to select the version
...@@ -620,7 +649,6 @@ private: ...@@ -620,7 +649,6 @@ private:
bool need_bound_check) { bool need_bound_check) {
PrimExpr new_loop_var; PrimExpr new_loop_var;
PrimExpr extent = end - start; PrimExpr extent = end - start;
auto make_nop = []() { auto make_nop = []() {
return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
}; };
...@@ -693,16 +721,125 @@ private: ...@@ -693,16 +721,125 @@ private:
} }
PopulateWaitCounts(new_blocks, &async_states_local); PopulateWaitCounts(new_blocks, &async_states_local);
auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local); auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
// Group blocks by their predicate conditions
PrimExpr current_condition = Bool(true);
Array<Stmt> current_stmts;
Array<PrimExpr> ordered_conditions;
Array<Array<Stmt>> condition_to_stmts;
for (const auto &stmt : stmts) {
if (const auto *realize = stmt.as<BlockRealizeNode>()) {
// Helper function to find IfThenElse through potential AttrStmt nodes
auto find_if_then_else =
[](Stmt body) -> std::pair<bool, const IfThenElseNode *> {
while (true) {
if (const auto *if_node = body.as<IfThenElseNode>()) {
return {true, if_node};
} else if (const auto *attr_node = body.as<AttrStmtNode>()) {
// Continue traversing through attributes
body = attr_node->body;
} else {
// No IfThenElse found
return {false, nullptr};
}
}
};
auto [has_if, if_then_else] = find_if_then_else(realize->block->body);
if (has_if) {
if (if_then_else->else_case.defined()) {
// IfThenElse nodes with else case are treated individually
if (!current_stmts.empty()) {
ordered_conditions.push_back(current_condition);
condition_to_stmts.push_back(current_stmts);
current_stmts = {};
}
current_condition = Bool(true);
current_stmts.push_back(stmt);
} else {
// If we encounter a new condition
if (!StructuralEqual()(if_then_else->condition,
current_condition)) {
// Store the current group if it's not empty
if (!current_stmts.empty()) {
ordered_conditions.push_back(current_condition);
condition_to_stmts.push_back(current_stmts);
current_stmts = {};
}
current_condition = if_then_else->condition;
}
BlockRealize new_realize = Downcast<BlockRealize>(stmt);
new_realize.CopyOnWrite()->block.CopyOnWrite()->body =
replace_if_then_else(new_realize->block->body,
if_then_else->condition);
current_stmts.push_back(new_realize);
}
} else {
if (!current_stmts.empty()) {
ordered_conditions.push_back(current_condition);
condition_to_stmts.push_back(current_stmts);
current_stmts = {};
}
current_condition = Bool(true);
current_stmts.push_back(stmt);
}
} else {
// Non-BlockRealize statements are treated individually
if (!current_stmts.empty()) {
ordered_conditions.push_back(current_condition);
condition_to_stmts.push_back(current_stmts);
current_stmts = {};
}
current_condition = Bool(true);
current_stmts.push_back(stmt);
}
}
// Add the last group if not empty
if (!current_stmts.empty()) {
ordered_conditions.push_back(current_condition);
condition_to_stmts.push_back(current_stmts);
}
// Build the final statement sequence with proper conditionals
Array<Stmt> final_stmts;
for (auto i = 0; i < ordered_conditions.size(); i++) {
Array<Stmt> condition_stmts = condition_to_stmts[i];
if (condition_stmts.empty())
continue;
// Create a sequence from the statements with this condition
Stmt stmt_block;
if (condition_stmts.size() == 1) {
stmt_block = condition_stmts[0];
} else {
stmt_block = SeqStmt(condition_stmts);
}
// If condition is not trivially true, wrap in if-then-else
if (!is_one(ordered_conditions[i]) &&
!analyzer_.CanProve(ordered_conditions[i] == true)) {
stmt_block = IfThenElse(ordered_conditions[i], stmt_block);
}
final_stmts.push_back(stmt_block);
}
// Use final_stmts instead of the original stmts
Stmt new_loop{nullptr}; Stmt new_loop{nullptr};
if (stmts.empty()) { if (final_stmts.empty()) {
return make_nop(); return make_nop();
} }
if (stmts.size() == 1) {
new_loop = stmts[0]; if (final_stmts.size() == 1) {
new_loop = final_stmts[0];
} else { } else {
new_loop = SeqStmt(stmts); new_loop = SeqStmt(final_stmts);
} }
if (!is_unit_loop) { if (!is_unit_loop) {
...@@ -979,11 +1116,11 @@ private: ...@@ -979,11 +1116,11 @@ private:
} }
if (has_stage) { if (has_stage) {
LOG(FATAL) LOG(FATAL)
<< "ValueError: Order of the software pipeline is not defined."; << "ValueError: Stage of the software pipeline is not defined.";
} }
if (has_order) { if (has_order) {
LOG(FATAL) LOG(FATAL)
<< "ValueError: Stage of the software pipeline is not defined."; << "ValueError: Order of the software pipeline is not defined.";
} }
return false; return false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment