"src/include/blockwise_2d_tensor_op.cuh" did not exist on "df228b3cf514ec23dcc1decacfc1973e7f9016d9"
Unverified commit 81b8c1b7 authored by Kuris, committed by GitHub

[Fix] Fix analyzer bind conflicting (#1446)

parent 869f021b
@@ -1090,112 +1090,114 @@ private:
      reducer_info = op->annotations.Get(attr::kReducerInfo)
                         ->as<Map<Var, ReducerInfo>>()
                         .value();
+     if (!result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
+       return IRMutatorWithAnalyzer::VisitStmt_(op);
+     }
+     // The analyzer will be modified in PartitionLoop and VectorizeLoop,
+     // so we save its state to prevent conflicting bindings.
+     auto saved_analyzer = analyzer_->Clone();
      For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
-     if (result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
      auto root = tvm::ffi::GetRef<For>(op);
      // This check is a workaround to support T.Parallel for local buffers.
      // For example:
      //   for i in T.Parallel(1024):
      //     A_local[i] = A_global[i]
      // Here, A_local is a register-local buffer held independently by each
      // thread, so explicit thread binding is not required.
      bool store_into_local = false;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() == "local") {
            store_into_local = true;
          }
          // If the case is like:
          //   for i in T.Parallel(1024):
          //     A_local[i] = B_global[i]
          //     A_frag[i] = A_global[i]
          // an exception will be raised in Parallel::LayoutInference.
        }
      });
      // This check is for a loop that only manipulates "local" buffers, e.g.
      //   for i in T.Parallel(1024):
      //     A_local[i] = B_local[i]
      // Though this might be illegal, we use PostOrderVisit to detect whether
      // the loop only manipulates "local" buffers, which indicates register
      // usage and justifies skipping thread binding.
      bool local_register_only = true;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() != "local") {
            local_register_only = false;
          }
        } else if (const auto *load = obj.as<BufferLoadNode>()) {
          if (load->buffer.scope() != "local") {
            local_register_only = false;
          }
        }
      });
      auto loop_layout = result_.for_map[root];
      // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
      // NOTE(lei): a bit ugly, we should rethink this part in the future.
      bool parallel_loop =
          !skip_thread_partition_ && !local_register_only && !store_into_local;
      if (parallel_loop) {
        for_node =
            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      }
      // If no thread bindings are provided, partition the loop.
      bool has_non_local = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (const auto *load = obj.as<BufferLoadNode>()) {
          String scope = load->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        } else if (const auto *store = obj.as<BufferStoreNode>()) {
          String scope = store->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        }
      });
      // Workaround: if a reducer is present, don't vectorize the loop.
      // The best solution would be to isolate the reduction axis from
      // vectorization.
      bool has_reducer = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (!has_reducer)
          if (const auto *store = obj.as<BufferStoreNode>()) {
            has_reducer = reducer_info.count(store->buffer->data) != 0;
          }
      });
      // If a cast operation exists, vectorization may still be required.
      bool has_cast_operations = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (const auto *cast = obj.as<CastNode>()) {
          // Check if this is a non-reducer store with a Cast operation.
          DataType src_type = cast->value.dtype();
          DataType dst_type = cast->dtype;
          bool src_ok = src_type.is_float() || src_type.is_bfloat() ||
                        src_type.is_float8_e4m3() || src_type.is_float8_e5m2();
          bool dst_ok = dst_type.is_float() || dst_type.is_bfloat() ||
                        dst_type.is_float8_e4m3() || dst_type.is_float8_e5m2();
          if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
            has_cast_operations = true;
          }
        }
      });
      if ((has_non_local || has_cast_operations) && !has_reducer) {
-       for_node = VectorizeLoop(for_node, analyzer_);
+       for_node = VectorizeLoop(for_node, saved_analyzer.get());
      }
      if (result_.predicate_map.count(root) && parallel_loop) {
        return IfThenElse(result_.predicate_map[root], for_node);
      } else {
        return for_node;
      }
-     }
-     return for_node;
    }
Stmt VisitStmt_(const AttrStmtNode *op) final {
......
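
Reviewer note: the sketch below is a minimal, hypothetical illustration of the failure mode this patch works around; the function name and the ranges are invented, and only `tvm::arith::Analyzer::Bind` with its `allow_override` flag is real TVM API. Per the comment added in the hunk, `PartitionLoop` and `VectorizeLoop` both modify the analyzer they are given, so sharing one mutated instance can lead the second pass to rebind an already-bound `Var` to a different value, which the analyzer rejects. Cloning `analyzer_` before the mutation gives `VectorizeLoop` a snapshot that has not yet absorbed the bindings added by `PartitionLoop`.

// Hypothetical sketch, not code from this patch: rebinding the same Var
// with a different value in one tvm::arith::Analyzer trips its consistency
// check unless allow_override is passed.
#include <tvm/arith/analyzer.h>
#include <tvm/ir/expr.h>
#include <tvm/tir/var.h>

void BindConflictSketch() {
  using namespace tvm;
  arith::Analyzer analyzer;
  tir::Var i("i", DataType::Int(32));

  // First binding: i ranges over [0, 128).
  analyzer.Bind(i, Range::FromMinExtent(0, 128));

  // Binding i again to a different range would normally be rejected;
  // allow_override is the explicit escape hatch.
  analyzer.Bind(i, Range::FromMinExtent(0, 256), /*allow_override=*/true);
}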
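
A second, optional illustration: the pass decides whether a parallel loop can skip thread binding by scanning the loop body for buffer accesses outside the "local" scope. The helper below is a hypothetical standalone restatement of that idiom and is not part of the patch; the function name is invented, while `tir::PostOrderVisit`, `BufferStoreNode`, and `BufferLoadNode` are real TVM API used the same way in the hunk above.

// Hypothetical restatement of the scope-detection idiom used above: walk a
// statement and report whether every BufferStore/BufferLoad targets the
// register-local "local" scope.
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

bool TouchesOnlyLocal(const tvm::tir::Stmt &body) {
  using namespace tvm;
  using namespace tvm::tir;
  bool local_only = true;
  PostOrderVisit(body, [&](const ObjectRef &obj) {
    if (const auto *store = obj.as<BufferStoreNode>()) {
      if (store->buffer.scope() != "local") local_only = false;
    } else if (const auto *load = obj.as<BufferLoadNode>()) {
      if (load->buffer.scope() != "local") local_only = false;
    }
  });
  return local_only;
}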