"test/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "4c1183c3a88db756f0ef2036c34e04e82554678c"
Commit 156ff85e authored by Lei Wang, committed by LeiWang1999
Browse files

[Bugfix] Put thread_extent into reduce (#640)

* [Enhancement] Update AllReduce operation to include thread offset in kernel generation

- Modified the `ReduceOp::Lower` method to incorporate the thread offset in the AllReduce kernel generation for the sm_90 architecture.
- This change improves the accuracy of thread management during reduction operations, enhancing performance on specific GPU architectures.

* [Enhancement] Refactor thread offset handling in AllReduce kernel generation

- Updated the `ReduceOp::Lower` method to streamline the handling of thread offset for AllReduce operations, ensuring consistent usage across different architectures.
- This change enhances code clarity and maintains performance improvements for the sm_90 architecture by reducing redundancy in thread offset calculations.
parent b5ac9bba
...@@ -225,15 +225,16 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -225,15 +225,16 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
std::stringstream ss; std::stringstream ss;
bool has_arch = T.target->attrs.count("arch") > 0; bool has_arch = T.target->attrs.count("arch") > 0;
auto thread_offset = T.thread_bounds->min;
if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") { if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") {
auto all_threads = T.thread_bounds->extent; auto all_threads = T.thread_bounds->extent;
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ", " << all_threads << reducing_threads << ", " << (*scale) << ", " << thread_offset
<< ">::run_hopper"; << ", " << all_threads << ">::run_hopper";
} else { } else {
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ", " << reducing_threads << ", " << (*scale) << ", " << thread_offset
<< (T.thread_bounds->min) << ">::run"; << ">::run";
} }
Array<PrimExpr> thread_reduce_args = { Array<PrimExpr> thread_reduce_args = {
StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)}; StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment