"awq/git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "1c5ccc791fa2cb0697db3b4070df1813f1736208"
Unverified Commit 81b8c1b7 authored by Kuris, committed by GitHub

[Fix] Fix analyzer bind conflicting (#1446)

parent 869f021b
@@ -1090,112 +1090,114 @@ private:
      reducer_info = op->annotations.Get(attr::kReducerInfo)
                         ->as<Map<Var, ReducerInfo>>()
                         .value();
+    if (!result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
+      return IRMutatorWithAnalyzer::VisitStmt_(op);
+    }
+    // the analyzer will be modified in PartitionLoop and VectorizeLoop
+    // we need to save its state to prevent conflicted bindings
+    auto saved_analyzer = analyzer_->Clone();
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
-    if (result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
-      auto root = tvm::ffi::GetRef<For>(op);
+    auto root = tvm::ffi::GetRef<For>(op);
    // This check is a workaround to support T.Parallel for local buffers.
    // For example:
    //   for i in T.Parallel(1024):
    //     A_local[i] = A_global[i]
    // Here, A_local is a register-local buffer held independently by each
    // thread, so explicit thread binding is not required.
    bool store_into_local = false;
    PostOrderVisit(root, [&](const ObjectRef &obj) {
      if (const auto *store = obj.as<BufferStoreNode>()) {
        if (store->buffer.scope() == "local") {
          store_into_local = true;
        }
        // if the case is like:
        // for i in T.Parallel(1024):
        //     A_local[i] = B_global[i]
        //     A_frag[i] = A_global[i]
        // exception will be raise in Parallel::LayoutInference
      }
    });
    // This check if for the loop that only manuplates "local" buffers,
    // for i in T.Parallel(1024):
    //     A_local[i] = B_local[i]
    // Though this might be illegal
    // We use PostOrderVisit to detect whether the loop only manuplates
    // "local" buffers, which indicates register usage and justifies skipping
    // thread binding.
    bool local_register_only = true;
    PostOrderVisit(root, [&](const ObjectRef &obj) {
      if (const auto *store = obj.as<BufferStoreNode>()) {
        if (store->buffer.scope() != "local") {
          local_register_only = false;
        }
      } else if (const auto *load = obj.as<BufferLoadNode>()) {
        if (load->buffer.scope() != "local") {
          local_register_only = false;
        }
      }
    });
    auto loop_layout = result_.for_map[root];
    // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
    // NOTE(lei): a bit ugly, we should rethink about this part in future.
    bool parallel_loop =
        !skip_thread_partition_ && !local_register_only && !store_into_local;
    if (parallel_loop) {
      for_node =
          PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
    }
    // If none thread bindings are provided, partition the loop
    bool has_non_local = false;
    PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
      if (const auto *load = obj.as<BufferLoadNode>()) {
        String scope = load->buffer.scope();
        if (scope != "local" && scope != "local.fragment") {
          has_non_local = true;
        }
      } else if (const auto *store = obj.as<BufferStoreNode>()) {
        String scope = store->buffer.scope();
        if (scope != "local" && scope != "local.fragment") {
          has_non_local = true;
        }
      }
    });
    // Workaround: if reducer is presented, don't vectorize loop
    // Best solution should be isolate reduction axis out of vectorization
    bool has_reducer = false;
    PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
      if (!has_reducer)
        if (const auto *store = obj.as<BufferStoreNode>()) {
          has_reducer = reducer_info.count(store->buffer->data) != 0;
        }
    });

    // If a cast operation exists, vectorization may still be required
    bool has_cast_operations = false;
    PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
      if (const auto *cast = obj.as<CastNode>()) {
        // Check if this is a non-reducer store with Cast operation
        DataType src_type = cast->value.dtype();
        DataType dst_type = cast->dtype;
        bool src_ok = src_type.is_float() || src_type.is_bfloat() ||
                      src_type.is_float8_e4m3() || src_type.is_float8_e5m2();
        bool dst_ok = dst_type.is_float() || dst_type.is_bfloat() ||
                      dst_type.is_float8_e4m3() || dst_type.is_float8_e5m2();
        if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
          has_cast_operations = true;
        }
      }
    });
    if ((has_non_local || has_cast_operations) && !has_reducer) {
-      for_node = VectorizeLoop(for_node, analyzer_);
+      for_node = VectorizeLoop(for_node, saved_analyzer.get());
    }

    if (result_.predicate_map.count(root) && parallel_loop) {
      return IfThenElse(result_.predicate_map[root], for_node);
    } else {
      return for_node;
    }
-    }
-    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
...
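For context, the "conflicting bindings" the new comment refers to come from the fact that PartitionLoop and VectorizeLoop both add variable bindings to the shared analyzer_, and binding an already-bound variable to a different value trips an internal check in tvm::arith::Analyzer. Below is a minimal, hedged sketch of that failure mode using only upstream TVM calls; the Clone() used by the commit is a TileLang-side helper and therefore appears only in comments, and the variable names are illustrative rather than taken from the pass.

#include <tvm/arith/analyzer.h>
#include <tvm/tir/var.h>

int main() {
  using namespace tvm;
  // Stand-in for the thread variable that the lowering pass binds.
  tir::Var tx("tx", DataType::Int(32));

  arith::Analyzer analyzer;
  // A first pass (think PartitionLoop) binds the variable into the analyzer.
  analyzer.Bind(tx, Range::FromMinExtent(0, 128));

  // A later pass (think VectorizeLoop) that reuses the *same* analyzer and
  // rebinds tx to a different range without allow_override hits an internal
  // "Trying to update var ..." check -- the conflict this commit avoids:
  //   analyzer.Bind(tx, Range::FromMinExtent(0, 256));  // would abort
  //
  // The commit sidesteps this by snapshotting the analyzer state before the
  // mutating passes run (saved_analyzer = analyzer_->Clone(), a TileLang
  // helper) and handing that snapshot to VectorizeLoop instead of analyzer_.
  return 0;
}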