"examples/git@developer.sourcefind.cn:OpenDAS/fairseq.git" did not exist on "af6b361c92178458e1fe938bf8247080f38e44b4"
Commit 6d3d4743 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[CI] Update CI configuration to run pytest with automatic parallelization (#393)

* Update CI configuration to run pytest with automatic parallelization using the '-n auto' option.

* Enhance Cython JIT Adapter Compilation Logic

- Improved the locking mechanism during the compilation of the Cython JIT adapter to prevent race conditions.
- Added checks to determine if another process has already compiled the library, reducing unnecessary recompilation.
- Cleaned up the code by removing redundant imports and ensuring proper handling of temporary files during compilation failures.
- Updated vectorization logic in loop_vectorize.cc to allow optional simplification of vectorized expressions.

This update enhances performance and reliability in the JIT compilation process.

* lint fix

* Update CI configuration to run pytest with 4 parallel jobs instead of auto-detection

* Add pytest markers for serial execution in MHA tests

- Added @pytest.mark.serial to multiple MHA test functions to ensure they run sequentially.
- This change improves test reliability by preventing potential race conditions during execution.

* Update TVM submodule and enhance vectorization logic in loop_vectorize.cc

- Updated the TVM submodule to the latest commit.
- Modified the vectorization logic to include optional simplification of vectorized expressions and added checks to ensure the usage of vectorized variables, improving performance and reliability in expression handling.

* Remove @pytest.mark.serial from multiple MHA test functions to allow parallel execution. This change enhances test performance by enabling concurrent test runs while maintaining reliability.

* Remove tvm_simplify_test.py file, eliminating the test for expression simplification in TVM. This cleanup helps streamline the codebase by removing unused test cases.

* Remove unused pytest import from test_tilelang_kernel_mha.py to streamline the test file.

* lint fix

* Update TVM submodule and refine vectorization logic in loop_vectorize.cc

- Updated the TVM submodule to the latest commit.
- Adjusted the return statements in loop_vectorize.cc to improve expression handling and ensure consistency in the visitor pattern.

* Refactor vectorization logic in loop_vectorize.cc

- Removed the check for the usage of the vectorized variable in the vectorization logic, simplifying the expression handling.
- This change enhances the clarity and efficiency of the vectorization process.

* Enhance vectorization checks in loop_vectorize.cc

- Added a check to ensure the vectorized expression uses the vectorized variable, improving the robustness of the vectorization logic.
- This change refines the expression handling and ensures that only valid vectorized expressions are processed.

* Implement non-local buffer checks for loop vectorization in layout_inference.cc

- Added logic to check for non-local buffer loads and stores before applying vectorization to loops. This enhancement ensures that vectorization is only applied when appropriate, improving the correctness of the loop transformations.

* Refactor buffer handling in pipeline planning and layout inference

- Renamed GlobalCopyPatternDetector to BufferRegionCollector for clarity and updated its logic to collect buffer read/write regions.
- Enhanced the handling of conditional expressions in pipeline planning, allowing for better management of stages related to conditional statements.
- Improved the processing of buffer regions during read/write operations, ensuring accurate tracking of buffer usage across different stages.

* Refactor vectorization checks in loop_vectorize.cc

- Removed the check for the usage of the vectorized variable in the vectorization logic, simplifying the expression handling.
- This change enhances the clarity and efficiency of the vectorization process, ensuring that valid vectorized expressions are processed without unnecessary checks.
parent 5f7bfeab
Subproject commit 6df0b88a90cd7a931ba171b57f0cf41d4cbfa2fe Subproject commit 742ed56bc08503c86b75bbd2a80e04db40e8600a
...@@ -556,7 +556,24 @@ private: ...@@ -556,7 +556,24 @@ private:
PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout); PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
} }
// If none thread bindings are provided, partition the loop // If none thread bindings are provided, partition the loop
for_node = VectorizeLoop(for_node); bool has_non_local = false;
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
if (const auto *load = obj.as<BufferLoadNode>()) {
String scope = load->buffer.scope();
if (scope != "local" && scope != "local.fragment") {
has_non_local = true;
}
} else if (const auto *store = obj.as<BufferStoreNode>()) {
String scope = store->buffer.scope();
if (scope != "local" && scope != "local.fragment") {
has_non_local = true;
}
}
});
if (has_non_local) {
for_node = VectorizeLoop(for_node);
}
if (result_.predicate_map.count(root) && parallel_loop) { if (result_.predicate_map.count(root) && parallel_loop) {
return IfThenElse(result_.predicate_map[root], for_node); return IfThenElse(result_.predicate_map[root], for_node);
......
...@@ -80,7 +80,6 @@ private: ...@@ -80,7 +80,6 @@ private:
} }
} }
UpdateVectorSize(node->indices, node->buffer); UpdateVectorSize(node->indices, node->buffer);
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
} }
void VisitStmt_(const BufferStoreNode *node) final { void VisitStmt_(const BufferStoreNode *node) final {
...@@ -88,7 +87,7 @@ private: ...@@ -88,7 +87,7 @@ private:
node->buffer.scope() == "shared.dyn") node->buffer.scope() == "shared.dyn")
has_nonlocal_memory_access_ = true; has_nonlocal_memory_access_ = true;
UpdateVectorSize(node->indices, node->buffer); UpdateVectorSize(node->indices, node->buffer);
return arith::IRVisitorWithAnalyzer::VisitStmt_(node); return arith::IRVisitorWithAnalyzer::VisitExpr(node->value);
} }
void VisitStmt_(const IfThenElseNode *node) final { void VisitStmt_(const IfThenElseNode *node) final {
...@@ -242,11 +241,14 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, ...@@ -242,11 +241,14 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
iter_var_size, target_vectorized_size)))); iter_var_size, target_vectorized_size))));
PrimExpr expr_transformed = analyzer->Simplify( PrimExpr expr_transformed = analyzer->Simplify(
Substitute(expr, {{var, v0 + v1 * target_vectorized_size}})); Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
PrimExpr expr_simplified = analyzer->Simplify(expr_transformed);
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size)); Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
PrimExpr expr_vectorized = PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
analyzer->Simplify(vectorizer.VisitExpr(expr_transformed));
auto ramp_node = expr_vectorized.as<RampNode>(); auto ramp_node = expr_vectorized.as<RampNode>();
if (!ramp_node) { if (!ramp_node) {
expr_vectorized = analyzer->Simplify(expr_vectorized);
// Broadcast value // Broadcast value
if (expr_vectorized.dtype().lanes() == 1) if (expr_vectorized.dtype().lanes() == 1)
return true; return true;
......
...@@ -61,17 +61,31 @@ bool MayConflict(Region region1, Region region2) { ...@@ -61,17 +61,31 @@ bool MayConflict(Region region1, Region region2) {
* 2. Source buffer must be in global memory scope * 2. Source buffer must be in global memory scope
* 3. Destination buffer must be in local or shared memory scope * 3. Destination buffer must be in local or shared memory scope
*/ */
class GlobalCopyPatternDetector : public StmtExprVisitor { class BufferRegionCollector : public StmtExprVisitor {
public: public:
static bool Detect(const Stmt &stmt) { BufferRegionCollector(Map<Var, Buffer> buffer_data_to_buffer)
GlobalCopyPatternDetector detector; : buffer_data_to_buffer_(buffer_data_to_buffer) {}
detector.VisitStmt(stmt);
return detector.is_global_copy_pattern_; Array<BufferRegion> GetReads() const { return reads_; }
}
Array<BufferRegion> GetWrites() const { return writes_; }
bool GetGlobalCopyPattern() const { return is_global_copy_pattern_; }
PrimExpr GetConditonalExpr() const { return conditonal_expr; }
private: private:
void VisitStmt_(const BufferStoreNode *op) final { void VisitStmt_(const BufferStoreNode *op) final {
Buffer store_buffer = op->buffer; Buffer store_buffer = op->buffer;
Array<PrimExpr> indices = op->indices;
// convert indices to region
Array<Range> region;
for (const auto &index : indices) {
region.push_back(Range::FromMinExtent(index, 1));
}
auto store_region = BufferRegion(store_buffer, region);
writes_.push_back(store_region);
is_global_read_ = false; is_global_read_ = false;
this->VisitExpr(op->value); this->VisitExpr(op->value);
if (is_global_read_ && (store_buffer.scope() == "shared" || if (is_global_read_ && (store_buffer.scope() == "shared" ||
...@@ -83,6 +97,16 @@ private: ...@@ -83,6 +97,16 @@ private:
} }
void VisitExpr_(const BufferLoadNode *op) final { void VisitExpr_(const BufferLoadNode *op) final {
auto load_buffer = op->buffer;
Array<PrimExpr> indices = op->indices;
// convert indices to region
Array<Range> region;
for (const auto &index : indices) {
region.push_back(Range::FromMinExtent(index, 1));
}
auto load_region = BufferRegion(load_buffer, region);
reads_.push_back(load_region);
if (op->buffer.scope() == "global") { if (op->buffer.scope() == "global") {
is_global_read_ = true; is_global_read_ = true;
} }
...@@ -90,7 +114,22 @@ private: ...@@ -90,7 +114,22 @@ private:
void VisitExpr_(const CallNode *op) final { void VisitExpr_(const CallNode *op) final {
auto args = op->args; auto args = op->args;
if (op->op.same_as(tir::builtin::if_then_else())) { if (op->op.same_as(builtin::address_of())) {
const BufferLoad load = Downcast<BufferLoad>(op->args[0]);
const BufferRegion buffer_region = BufferRegion::FullRegion(load->buffer);
// because we only care about the buffer itself instead of indices
reads_.push_back(buffer_region);
} else if (op->op.same_as(builtin::tvm_access_ptr())) {
const VarNode *buffer_var = op->args[1].as<VarNode>();
ICHECK(buffer_var);
auto it = buffer_data_to_buffer_.find(GetRef<Var>(buffer_var));
if (it != buffer_data_to_buffer_.end()) {
const Buffer &buffer = (*it).second;
const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
// because we only care about the buffer itself instead of indices
reads_.push_back(buffer_region);
}
} else if (op->op.same_as(tir::builtin::if_then_else())) {
// Simplify nested if_then_else // Simplify nested if_then_else
// if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr // if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr
// } } else { else_expr } // } } else { else_expr }
...@@ -98,23 +137,31 @@ private: ...@@ -98,23 +137,31 @@ private:
const PrimExpr &cond = op->args[0]; const PrimExpr &cond = op->args[0];
const PrimExpr &then_expr = op->args[1]; const PrimExpr &then_expr = op->args[1];
const PrimExpr &else_expr = op->args[2]; const PrimExpr &else_expr = op->args[2];
conditonal_expr = cond;
this->VisitExpr(then_expr); this->VisitExpr(then_expr);
this->VisitExpr(else_expr); this->VisitExpr(else_expr);
} else {
StmtExprVisitor::VisitExpr_(op);
} }
} }
void VisitStmt_(const IfThenElseNode *op) final { void VisitStmt_(const IfThenElseNode *op) final {
// Skip condition // Skip condition
this->VisitStmt(op->then_case); this->VisitStmt(op->then_case);
conditonal_expr = op->condition;
if (op->else_case.defined()) { if (op->else_case.defined()) {
this->VisitStmt(op->else_case.value()); this->VisitStmt(op->else_case.value());
} }
} }
private: private:
Map<Var, Buffer> buffer_data_to_buffer_;
Array<BufferRegion> reads_;
Array<BufferRegion> writes_;
bool is_global_read_ = false; bool is_global_read_ = false;
bool under_buffer_store_ = false; bool under_buffer_store_ = false;
bool is_global_copy_pattern_ = false; bool is_global_copy_pattern_ = false;
PrimExpr conditonal_expr;
}; };
class PipelinePlanner : public StmtExprMutator { class PipelinePlanner : public StmtExprMutator {
...@@ -151,7 +198,10 @@ private: ...@@ -151,7 +198,10 @@ private:
int original_order; int original_order;
int order = -1, stage = -1; int order = -1, stage = -1;
bool copy_stage = false; bool copy_stage = false;
bool prepare_for_condition = false;
int last_use_stage = -1; int last_use_stage = -1;
// represent the stage is used in a conditional statement
PrimExpr conditonal_expr;
}; };
PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) { PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
...@@ -159,13 +209,14 @@ private: ...@@ -159,13 +209,14 @@ private:
/*body*/ stmt); /*body*/ stmt);
Array<Array<BufferRegion>> access = Array<Array<BufferRegion>> access =
GetBlockReadWriteRegion(block, buffer_data_to_buffer_); GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto collector = BufferRegionCollector(buffer_data_to_buffer_);
collector(block);
PipelineStageInfo pinfo; PipelineStageInfo pinfo;
pinfo.reads = std::move(access[0]); pinfo.reads = std::move(collector.GetReads());
pinfo.writes = std::move(access[1]); pinfo.writes = std::move(collector.GetWrites());
pinfo.original_order = idx; pinfo.original_order = idx;
pinfo.copy_stage = GlobalCopyPatternDetector::Detect(stmt); pinfo.copy_stage = collector.GetGlobalCopyPattern();
pinfo.conditonal_expr = collector.GetConditonalExpr();
return std::move(pinfo); return std::move(pinfo);
} }
...@@ -235,6 +286,25 @@ private: ...@@ -235,6 +286,25 @@ private:
pipeline_stage_infos.push_back(std::move(pinfo)); pipeline_stage_infos.push_back(std::move(pinfo));
} }
// process the conditional stage
// assign conditional stage (analysis the copy stage)
for (auto &pinfo : pipeline_stage_infos) {
for (const auto &write : pinfo.writes) {
for (const auto &other : pipeline_stage_infos) {
if (other.conditonal_expr.defined()) {
auto check_var = [&](const ObjectRef &n) {
if (const auto *buffer_load = n.as<BufferLoadNode>()) {
if (buffer_load->buffer == write->buffer) {
pinfo.prepare_for_condition = true;
}
}
};
PostOrderVisit(other.conditonal_expr, check_var);
}
}
}
}
// analysis use-def chain // analysis use-def chain
for (auto &pinfo : pipeline_stage_infos) { for (auto &pinfo : pipeline_stage_infos) {
for (int i = pinfo.original_order + 1; for (int i = pinfo.original_order + 1;
...@@ -269,47 +339,51 @@ private: ...@@ -269,47 +339,51 @@ private:
// Making stages and orders // Making stages and orders
int order_idx = 0; int order_idx = 0;
// Create pipeline stages and assign order
for (auto &pinfo : pipeline_stage_infos) { for (auto &pinfo : pipeline_stage_infos) {
if (pinfo.copy_stage && pinfo.last_use_stage != -1) // Skip elements that must be in first stage:
// 1. Copy stages (with active last_use_stage)
// 2. Condition preparation stages
if ((pinfo.copy_stage && pinfo.last_use_stage != -1) ||
pinfo.prepare_for_condition)
continue; continue;
// Main logic stage assignment:
// - Increment order index
// - Assign to new stage (current num_stages)
pinfo.order = order_idx++; pinfo.order = order_idx++;
pinfo.stage = num_stages; pinfo.stage = num_stages;
bool used_by_copy = false;
for (const auto &write : pinfo.writes) {
for (const auto &other : pipeline_stage_infos) {
if (other.copy_stage) {
for (const auto &read : other.reads) {
if (write->buffer == read->buffer &&
MayConflict(write->region, read->region)) {
used_by_copy = true;
break;
}
}
}
}
}
if (used_by_copy) {
pinfo.stage = 0;
}
for (auto &pinfo_1 : pipeline_stage_infos) { for (auto &pinfo_1 : pipeline_stage_infos) {
if (pinfo_1.copy_stage && if ((pinfo_1.copy_stage &&
pinfo_1.last_use_stage == pinfo.original_order) { pinfo_1.last_use_stage == pinfo.original_order)) {
pinfo_1.order = order_idx++; pinfo_1.order = order_idx++;
pinfo_1.stage = 0; pinfo_1.stage = 0;
} }
} }
} }
// process the tail copy stage
// Handle trailing unassigned copy stages:
// These are typically final copy operations needing post-main-stage
// insertion
auto &head_pinfo = pipeline_stage_infos.at(0); auto &head_pinfo = pipeline_stage_infos.at(0);
if (head_pinfo.order == -1) { int unassigned_order_elem = -1;
for (auto &pinfo : pipeline_stage_infos) {
pinfo.order++; // Process dependent copy stages:
// Insert copy stages after current stage but assign to stage 0
// and adjust the order index
for (auto &pinfo : pipeline_stage_infos) {
if (pinfo.order == unassigned_order_elem) {
pinfo.order = unassigned_order_elem++;
// traverse the from the next info
for (auto it = pipeline_stage_infos.begin() + unassigned_order_elem;
it != pipeline_stage_infos.end(); it++) {
it->order += 1;
}
pinfo.stage = 0;
order_idx++;
} }
head_pinfo.stage = 0;
order_idx++;
} }
ICHECK(size_t(order_idx) == pipeline_stage_infos.size()) ICHECK(size_t(order_idx) == pipeline_stage_infos.size())
......
...@@ -20,6 +20,7 @@ import sys ...@@ -20,6 +20,7 @@ import sys
import sysconfig import sysconfig
import hashlib import hashlib
import os import os
import fcntl
from pathlib import Path from pathlib import Path
import logging import logging
...@@ -61,7 +62,6 @@ def get_cached_lib(source_code: str) -> Tuple[Optional[ctypes.CDLL], Path]: ...@@ -61,7 +62,6 @@ def get_cached_lib(source_code: str) -> Tuple[Optional[ctypes.CDLL], Path]:
code_hash = hashlib.sha256(source_code.encode()).hexdigest() code_hash = hashlib.sha256(source_code.encode()).hexdigest()
cache_path = get_cache_dir() / f"{code_hash}.so" cache_path = get_cache_dir() / f"{code_hash}.so"
lock_file = cache_path.with_suffix('.lock') lock_file = cache_path.with_suffix('.lock')
import fcntl
with open(lock_file, 'w') as lock: with open(lock_file, 'w') as lock:
fcntl.flock(lock.fileno(), fcntl.LOCK_EX) fcntl.flock(lock.fileno(), fcntl.LOCK_EX)
try: try:
...@@ -91,6 +91,7 @@ with open(cython_wrapper_path, "r") as f: ...@@ -91,6 +91,7 @@ with open(cython_wrapper_path, "r") as f:
md5_path = cache_dir / "md5.txt" md5_path = cache_dir / "md5.txt"
code_hash = hashlib.sha256(cython_wrapper_code.encode()).hexdigest() code_hash = hashlib.sha256(cython_wrapper_code.encode()).hexdigest()
cache_path = cache_dir / f"{code_hash}.so" cache_path = cache_dir / f"{code_hash}.so"
lock_file = cache_path.with_suffix('.lock')
# Check if cached version exists and is valid # Check if cached version exists and is valid
need_compile = True need_compile = True
...@@ -106,32 +107,45 @@ with open(cython_wrapper_path, "r") as f: ...@@ -106,32 +107,45 @@ with open(cython_wrapper_path, "r") as f:
logger.info("No cached version found for cython jit adapter, need to compile...") logger.info("No cached version found for cython jit adapter, need to compile...")
if need_compile: if need_compile:
logger.info("Compiling cython jit adapter...") logger.info("Waiting for lock to compile cython jit adapter...")
temp_path = cache_dir / f"temp_{code_hash}.so" with open(lock_file, 'w') as lock:
try: fcntl.flock(lock.fileno(), fcntl.LOCK_EX)
with open(md5_path, "w") as f: try:
f.write(code_hash) # After acquiring the lock, check again if the file has been compiled by another process
if md5_path.exists() and library_path.exists():
# compile the cython_wrapper.pyx file into .cpp with open(md5_path, "r") as f:
cython = get_cython_compiler() cached_hash = f.read().strip()
if cython is None: if cached_hash == code_hash:
raise Exception("Cython is not installed, please install it first.") logger.info(
os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}") "Another process has already compiled the file, using it...")
python_include_path = sysconfig.get_path("include") need_compile = False
cc = get_cplus_compiler()
command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {temp_path}" if need_compile:
os.system(command) logger.info("Compiling cython jit adapter...")
temp_path = cache_dir / f"temp_{code_hash}.so"
# rename the temp file to the library file
temp_path.rename(library_path) with open(md5_path, "w") as f:
except Exception as e: f.write(code_hash)
if temp_path.exists():
temp_path.unlink() # compile the cython_wrapper.pyx file into .cpp
raise Exception(f"Failed to compile cython jit adapter: {e}") from e cython = get_cython_compiler()
finally: if cython is None:
lock_file = cache_path.with_suffix('.lock') raise Exception("Cython is not installed, please install it first.")
if lock_file.exists(): os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}")
lock_file.unlink() python_include_path = sysconfig.get_path("include")
cc = get_cplus_compiler()
command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {temp_path}"
os.system(command)
# rename the temp file to the library file
temp_path.rename(library_path)
except Exception as e:
if 'temp_path' in locals() and temp_path.exists():
temp_path.unlink()
raise Exception(f"Failed to compile cython jit adapter: {e}") from e
finally:
if lock_file.exists():
lock_file.unlink()
# add the .so file to the sys.path # add the .so file to the sys.path
cache_dir_str = str(cache_dir) cache_dir_str = str(cache_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment