Unverified Commit a7c9a8b9 authored by Siyuan Feng's avatar Siyuan Feng Committed by GitHub
Browse files

Refactor to support upstream tvm (#595)

**Summarize part of the rebase pr:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters using `typing`:
   ```python
   dyn_type = T.int32 or T.float
   @T.prim_func
   def main(
       a: dyn_type,
   )
   ```

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
           for bx in...
parent 8edd6941
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*! /*!
* \file clasuter_planning.cc * \file clasuter_planning.cc
* \brief Plan the cluster for GPU(sm90+) blocks * \brief Plan the cluster for GPU(sm90+) blocks
*/ */
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
...@@ -132,8 +115,10 @@ tvm::transform::Pass ClusterPlanning() { ...@@ -132,8 +115,10 @@ tvm::transform::Pass ClusterPlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {}); return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.ClusterPlanning") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(ClusterPlanning); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning);
});
} // namespace transform } // namespace transform
} // namespace tir } // namespace tir
......
...@@ -599,7 +599,7 @@ public: ...@@ -599,7 +599,7 @@ public:
return Scalarize(GetRef<Stmt>(op)); return Scalarize(GetRef<Stmt>(op));
} }
Stmt then_case = this->VisitStmt(op->then_case); Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = NullOpt; Optional<Stmt> else_case = std::nullopt;
if (op->else_case) { if (op->else_case) {
else_case = this->VisitStmt(op->else_case.value()); else_case = this->VisitStmt(op->else_case.value());
} }
...@@ -681,10 +681,6 @@ public: ...@@ -681,10 +681,6 @@ public:
stmt = Substitute(stmt, {{var_, idx}}); stmt = Substitute(stmt, {{var_, idx}});
return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt); return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
} }
// ProducerStore
Stmt VisitStmt_(const ProducerStoreNode *op) final {
LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc";
}
private: private:
// analyzer // analyzer
......
#include "../op/builtin.h" #include "../op/builtin.h"
#include <tvm/runtime/registry.h> #include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/data_type_rewriter.h> #include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
...@@ -85,8 +86,11 @@ tvm::transform::Pass ConfigIndexBitwidth() { ...@@ -85,8 +86,11 @@ tvm::transform::Pass ConfigIndexBitwidth() {
return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {}); return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.ConfigIndexBitwidth") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(ConfigIndexBitwidth); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ConfigIndexBitwidth",
ConfigIndexBitwidth);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
#include "./storage_access.h" #include "./storage_access.h"
#include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h" #include "arith/ir_visitor_with_analyzer.h"
#include <tvm/runtime/registry.h> #include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
...@@ -115,8 +116,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() { ...@@ -115,8 +116,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() {
{}); {});
} }
TVM_REGISTER_GLOBAL("tl.transform.EliminateStorageSyncForMBarrier") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(EliminateStorageSyncForMBarrier); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier",
EliminateStorageSyncForMBarrier);
});
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_mutator_with_analyzer.h"
#include "tir/transforms/ir_utils.h" #include "tir/transforms/ir_utils.h"
#include <tvm/arith/iter_affine_map.h> #include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/data_type_rewriter.h> #include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
...@@ -352,12 +353,7 @@ private: ...@@ -352,12 +353,7 @@ private:
}; };
PrimFunc FlattenBufferRewriter(PrimFunc f) { PrimFunc FlattenBufferRewriter(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules return BufferFlattener::Flatten(f);
if (!IsFromLegacyTESchedule(f)) {
return BufferFlattener::Flatten(f);
} else {
return f;
}
} }
using namespace tir::transform; using namespace tir::transform;
...@@ -368,7 +364,10 @@ tvm::transform::Pass FlattenBuffer() { ...@@ -368,7 +364,10 @@ tvm::transform::Pass FlattenBuffer() {
return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {}); return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.FlattenBuffer").set_body_typed(FlattenBuffer); TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.FlattenBuffer", FlattenBuffer);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* \brief Legalize the program from frontend * \brief Legalize the program from frontend
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
...@@ -88,8 +89,10 @@ Pass FrontendLegalize() { ...@@ -88,8 +89,10 @@ Pass FrontendLegalize() {
return CreatePrimFuncPass(pass_func, 0, "tl.FrontendLegalize", {}); return CreatePrimFuncPass(pass_func, 0, "tl.FrontendLegalize", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.FrontendLegalize") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(FrontendLegalize); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.FrontendLegalize", FrontendLegalize);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \brief Bind the If Stmt to each Stmt in SeqStmt * \brief Bind the If Stmt to each Stmt in SeqStmt
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
...@@ -80,7 +81,10 @@ tvm::transform::Pass IfStmtBinding() { ...@@ -80,7 +81,10 @@ tvm::transform::Pass IfStmtBinding() {
return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {}); return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.IfStmtBinding").set_body_typed(IfStmtBinding); TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.IfStmtBinding", IfStmtBinding);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* \brief Inject fence between generic and async proxies (sm90+) * \brief Inject fence between generic and async proxies (sm90+)
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
...@@ -193,8 +194,10 @@ tvm::transform::Pass InjectFenceProxy() { ...@@ -193,8 +194,10 @@ tvm::transform::Pass InjectFenceProxy() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.InjectFenceProxy") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(InjectFenceProxy); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectFenceProxy", InjectFenceProxy);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* \brief Transform annotated loops into pipelined one that parallelize * \brief Transform annotated loops into pipelined one that parallelize
* producers and consumers * producers and consumers
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target.h> #include <tvm/target/target.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
...@@ -737,7 +738,7 @@ private: ...@@ -737,7 +738,7 @@ private:
} }
if (!is_unit_loop) { if (!is_unit_loop) {
Map<String, ObjectRef> preserved_annotations; Map<String, Any> preserved_annotations;
for (const auto &kv : pipeline_loop_->annotations) { for (const auto &kv : pipeline_loop_->annotations) {
const String &key = kv.first; const String &key = kv.first;
if (kv.first != tir::attr::software_pipeline_stage && if (kv.first != tir::attr::software_pipeline_stage &&
...@@ -748,7 +749,7 @@ private: ...@@ -748,7 +749,7 @@ private:
} }
new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent, new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
std::move(new_loop), NullOpt, preserved_annotations); std::move(new_loop), std::nullopt, preserved_annotations);
} }
// Update producer heads in the global async states. // Update producer heads in the global async states.
for (const auto &[stage_id, state] : async_states_local) { for (const auto &[stage_id, state] : async_states_local) {
...@@ -955,7 +956,7 @@ private: ...@@ -955,7 +956,7 @@ private:
std::unordered_set<int> pipeline_async_stages; std::unordered_set<int> pipeline_async_stages;
if (auto annot = if (auto annot =
op->annotations.Get(tir::attr::software_pipeline_async_stages)) { op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
for (auto s : Downcast<Array<Integer>>(annot)) { for (auto s : Downcast<Array<Integer>>(annot.value())) {
pipeline_async_stages.insert(s->value); pipeline_async_stages.insert(s->value);
} }
} }
...@@ -1038,8 +1039,11 @@ tir::transform::Pass InjectSoftwarePipeline() { ...@@ -1038,8 +1039,11 @@ tir::transform::Pass InjectSoftwarePipeline() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.InjectSoftwarePipeline") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(InjectSoftwarePipeline); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
InjectSoftwarePipeline);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* \brief Replace copy from global to shared with async copy * \brief Replace copy from global to shared with async copy
* \file inject_ptx_async_copy.cc * \file inject_ptx_async_copy.cc
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
...@@ -231,8 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() { ...@@ -231,8 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.InjectPTXAsyncCopy") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(InjectPTXAsyncCopy); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
*/ */
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
...@@ -306,8 +307,10 @@ tvm::transform::Pass InjectTmaBarrier() { ...@@ -306,8 +307,10 @@ tvm::transform::Pass InjectTmaBarrier() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.InjectTmaBarrier") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(InjectTmaBarrier); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \brief infer the fragment/shared memory layout * \brief infer the fragment/shared memory layout
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h> #include <tvm/tir/index_map.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
...@@ -138,11 +139,10 @@ public: ...@@ -138,11 +139,10 @@ public:
if (layout_map.count(buffer)) { if (layout_map.count(buffer)) {
// If replicate size of this buffer is greater than the old one // If replicate size of this buffer is greater than the old one
if (buffer.scope() == "local.fragment" && if (buffer.scope() == "local.fragment" &&
level != InferLevel::kStrict && level != InferLevel::kStrict) {
!strict_layout_map.count(buffer)) { const FragmentNode *dst_layout = layout.as<FragmentNode>();
const FragmentNode *dst_layout = layout.as<Fragment>().get();
const FragmentNode *src_layout = const FragmentNode *src_layout =
layout_map[buffer].as<Fragment>().get(); layout_map[buffer].as<FragmentNode>();
if (as_const_int(dst_layout->ReplicateExtent()) && if (as_const_int(dst_layout->ReplicateExtent()) &&
as_const_int(src_layout->ReplicateExtent()) && as_const_int(src_layout->ReplicateExtent()) &&
(*as_const_int(dst_layout->ReplicateExtent()) > (*as_const_int(dst_layout->ReplicateExtent()) >
...@@ -313,7 +313,7 @@ private: ...@@ -313,7 +313,7 @@ private:
auto var = call->args[1].as<Var>().value(); auto var = call->args[1].as<Var>().value();
return buffer_data_to_buffer_[var]; return buffer_data_to_buffer_[var];
} }
return NullOpt; return std::nullopt;
} }
void addToUseList(const Buffer &buffer) { void addToUseList(const Buffer &buffer) {
...@@ -354,11 +354,9 @@ private: ...@@ -354,11 +354,9 @@ private:
} }
if (op->annotations.count(attr::kLayoutMap)) { if (op->annotations.count(attr::kLayoutMap)) {
// Check if the layout map is Map<Var, Layout> // Check if the layout map is Map<Var, Layout>
auto map = op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>(); auto map =
ICHECK(map.defined()) << "layout map is not defined"; op->annotations.Get(attr::kLayoutMap)->as<Map<Var, Layout>>().value();
ICHECK(map.value().defined()) << "layout map is not defined"; for (const auto &[var, layout] : map) {
for (const auto &[var, layout] : map.value()) {
ICHECK(buffer_data_to_buffer_.count(var)) ICHECK(buffer_data_to_buffer_.count(var))
<< "buffer " << var << " is not found in the block"; << "buffer " << var << " is not found in the block";
auto buffer = buffer_data_to_buffer_[var]; auto buffer = buffer_data_to_buffer_[var];
...@@ -519,8 +517,10 @@ tvm::transform::Pass LayoutInference() { ...@@ -519,8 +517,10 @@ tvm::transform::Pass LayoutInference() {
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.LayoutInference") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(LayoutInference); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \brief legalize safe memory access * \brief legalize safe memory access
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
...@@ -313,7 +314,7 @@ private: ...@@ -313,7 +314,7 @@ private:
} }
if (op->annotations.count(attr::kPaddingMap)) { if (op->annotations.count(attr::kPaddingMap)) {
auto map = op->annotations.Get(attr::kPaddingMap) auto map = op->annotations.Get(attr::kPaddingMap)
.as<Map<Var, PrimExpr>>() ->as<Map<Var, PrimExpr>>()
.value(); .value();
for (const auto &[var, padding] : map) { for (const auto &[var, padding] : map) {
ICHECK(buffer_data_to_buffer_.count(var)) ICHECK(buffer_data_to_buffer_.count(var))
...@@ -353,8 +354,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() { ...@@ -353,8 +354,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
} }
// Register the pass globally so it can be used in the compilation pipeline // Register the pass globally so it can be used in the compilation pipeline
TVM_REGISTER_GLOBAL("tl.transform.LegalizeSafeMemoryAccess") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(LegalizeSafeMemoryAccess); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess",
LegalizeSafeMemoryAccess);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* \brief infer the fragment/shared memory layout * \brief infer the fragment/shared memory layout
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
...@@ -88,8 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() { ...@@ -88,8 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
} }
// Register the pass globally so it can be used in the compilation pipeline // Register the pass globally so it can be used in the compilation pipeline
TVM_REGISTER_GLOBAL("tl.transform.LegalizeVectorizedLoop") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(LegalizeVectorizedLoop); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeVectorizedLoop",
LegalizeVectorizedLoop);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <cstdint> #include <cstdint>
#include <tvm/arith/iter_affine_map.h> #include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
...@@ -145,9 +146,7 @@ private: ...@@ -145,9 +146,7 @@ private:
const DataType &access_type = buffer->dtype; const DataType &access_type = buffer->dtype;
// i // 2, i % 8 can also be vectorized as factor 16 // i // 2, i % 8 can also be vectorized as factor 16
int max_vector_size = vector_load_bits_max_ / access_type.bits(); int max_vector_size = vector_load_bits_max_ / access_type.bits();
if (access_type.is_e4m3_float8() or access_type.is_e5m2_float8()) {
max_vector_size = 1; // [temporarily] do not vectorize float8
}
// so we should disable this GCD optimization // so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value); max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
...@@ -532,8 +531,11 @@ tvm::transform::Pass LoopVectorizeDynamic() { ...@@ -532,8 +531,11 @@ tvm::transform::Pass LoopVectorizeDynamic() {
} }
// Register the pass globally so it can be used in the compilation pipeline // Register the pass globally so it can be used in the compilation pipeline
TVM_REGISTER_GLOBAL("tl.transform.LoopVectorizeDynamic") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(LoopVectorizeDynamic); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LoopVectorizeDynamic",
LoopVectorizeDynamic);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_device_kernel_launch.cc
* \brief Split device function from host.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
namespace {
/*!
 * \brief Launch metadata collected for one PrimFunc, used when rewriting
 *        cross-target calls into device kernel launches.
 */
struct KernelInfo {
  // The device on which the PrimFunc runs
  Target target;
  // The externally visible symbol which may refer to the PrimFunc
  // when launching a device kernel.
  String global_symbol;
  // The parameters accepted by the PrimFunc. Used to rewrite
  // `launch_args` to be in terms of the calling scope.
  Array<Var> params;
  // The launch parameters that should annotate the PrimFunc, if the
  // kernel is ever called from the host.
  Array<String> launch_params;
  // Additional arguments which must be provided to the host-side
  // PackedFunc. These may be in terms of the function's parameters
  // (e.g. a function that computes the average of `N` elements, and
  // which must be launched with `N` CUDA threads).
  Array<PrimExpr> launch_args;
  // The extent of each thread axis, keyed by its thread tag.
  Map<String, PrimExpr> thread_extent;
  // The amount of dynamic shared memory used; std::nullopt when the
  // function allocates no dynamic shared buffer.
  Optional<PrimExpr> dyn_shmem_size{std::nullopt};
};
/*!
* \brief Visitor class to collect device-side program information.
*/
class DeviceInfoCollector : public StmtVisitor {
public:
static KernelInfo Collect(const GlobalVar &gvar, const PrimFunc &func) {
DeviceInfoCollector collector;
collector.info_.target =
func->GetAttr<Target>(tvm::attr::kTarget).value().WithoutHost();
collector.info_.params = func->params;
collector(func->body);
// The dynamic shared memory is required to be the last of the
// kernel launch parameters
if (collector.dyn_shmem_size) {
collector.info_.launch_params.push_back(
tvm::runtime::launch_param::kUseDynamicSharedMemoryTag);
}
collector.info_.global_symbol =
func->GetAttr<String>(tvm::attr::kGlobalSymbol)
.value_or(gvar->name_hint);
collector.info_.launch_args = collector.info_.launch_params.Map(
[&](const auto &param) { return collector.GetArgument(param); });
collector.info_.dyn_shmem_size = collector.dyn_shmem_size;
collector.info_.thread_extent = collector.thread_extent;
return collector.info_;
}
private:
PrimExpr GetArgument(const String &launch_param) const {
if (launch_param ==
tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) {
CHECK(dyn_shmem_size.defined())
<< "Compute kernel requires launch parameter \"" << launch_param
<< "\", but PrimFunc did not contain Allocate node with shared "
"dynamic scope.";
return dyn_shmem_size.value();
}
auto extent = thread_extent.Get(launch_param);
CHECK(extent) << "Compute kernel requires launch parameter \""
<< launch_param
<< "\", but PrimFunc does not contain AttrStmt \""
<< tir::attr::thread_extent
<< "\" defining this thread extent";
return extent.value();
}
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
ICHECK_NE(iv->thread_tag.length(), 0U);
// thread_extent can appear multiple times
// use the first appearance as def.
if (!defined_thread.count(iv.get())) {
defined_thread.insert(iv.get());
info_.launch_params.push_back(iv->thread_tag);
thread_extent.Set(iv->thread_tag, op->value);
}
}
StmtVisitor::VisitStmt_(op);
}
void VisitStmt_(const AllocateNode *op) final {
auto storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
if (storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn") {
ICHECK(!dyn_shmem_size.defined())
<< "Only one dynamic shared memory allocation is allowed.";
ICHECK_GT(op->extents.size(), 0);
PrimExpr dyn_size = Integer(1);
for (const auto &extent : op->extents) {
dyn_size *= extent;
}
dyn_size *= op->dtype.bytes() * op->dtype.lanes();
dyn_shmem_size = dyn_size;
}
StmtVisitor::VisitStmt_(op);
}
// The collected results
KernelInfo info_;
// recording what thread axis have been visited.
std::unordered_set<const IterVarNode *> defined_thread;
// The extent of each thread
Map<String, PrimExpr> thread_extent;
// The amount of dynamic shared memory used
Optional<PrimExpr> dyn_shmem_size{std::nullopt};
};
/*!
 * \brief Mutator that removes successful `T.ret(0)` statements from a
 *        device kernel body, since kernel launches return void.
 */
class ReturnRemover : public StmtExprMutator {
public:
  /*! \brief Apply the rewrite to \p stmt and return the cleaned body. */
  static Stmt Apply(const Stmt &stmt) {
    ReturnRemover pass;
    return pass(stmt);
  }

private:
  using Parent = StmtExprMutator;

  Stmt VisitStmt_(const EvaluateNode *op) override {
    const auto *call = op->value.as<CallNode>();
    if (call == nullptr || !call->op.same_as(builtin::ret())) {
      return Parent::VisitStmt_(op);
    }
    // Only a successful return (status 0) is permitted in a kernel.
    ICHECK_EQ(call->args.size(), 1);
    const auto *status = call->args[0].as<IntImmNode>();
    ICHECK(status && status->value == 0)
        << "Device kernel may only contain successful return, T.ret(0)";
    // Replace the return with a no-op evaluate.
    return Evaluate(0);
  }

  PrimExpr VisitExpr_(const CallNode *op) override {
    if (op->op.same_as(builtin::ret())) {
      LOG(FATAL) << "Call to builtin::ret() should only appear within an "
                    "Evaluate node";
    }
    return Parent::VisitExpr_(op);
  }
};
} // namespace
/*!
 * \brief Mutator that rewrites calls to other PrimFuncs based on the
 *        caller/callee target relationship, and then updates callee
 *        attributes to match how each function was called.
 */
class DeviceKernelMutator : public StmtExprMutator {
public:
  using Parent = StmtExprMutator;

  /*! \brief Construct from pre-collected per-function kernel info. */
  explicit DeviceKernelMutator(
      std::unordered_map<const GlobalVarNode *, KernelInfo> device_info_map)
      : device_info_map_(std::move(device_info_map)) {}

  /*!
   * \brief Rewrite call sites inside \p func, using the function's own
   *        target as the caller target for cross-target detection.
   * \param gvar The GlobalVar of the function being rewritten.
   * \param func The function whose body is rewritten (copy-on-write).
   * \return The (possibly updated) function.
   */
  PrimFunc RewriteKernelLaunchSite(const GlobalVar &gvar, PrimFunc func) {
    // Not re-entrant: the caller target must be unset on entry.
    ICHECK(!current_target_.defined());
    auto it = device_info_map_.find(gvar.get());
    ICHECK(it != device_info_map_.end());
    current_target_ = it->second.target;
    auto body = VisitStmt(func->body);
    if (!body.same_as(func->body)) {
      func.CopyOnWrite()->body = body;
    }
    current_target_ = std::nullopt;
    return func;
  }

  /*!
   * \brief Attach launch-related attributes to a callee, based on how it
   *        was reached during RewriteKernelLaunchSite.
   * \param gvar The GlobalVar of the callee.
   * \param func The callee function to annotate.
   * \return The annotated function.
   */
  PrimFunc UpdateKernelAttributes(const GlobalVar &gvar, PrimFunc func) const {
    bool is_kernel_launch = device_kernel_launch_.count(gvar.get());
    bool is_call_extern = extern_function_call_.count(gvar.get());
    // A function must be lowered through a single strategy; mixed
    // call_extern / kernel-launch callees are unsupported.
    CHECK(!is_kernel_launch || !is_call_extern)
        << "Function " << gvar << " has multiple callees, "
        << "and would need to be lowered into a call_extern at some call "
           "sites, "
        << "and a device kernel launch at others. "
        << "This case is not yet supported.";
    if (is_kernel_launch || is_call_extern) {
      func =
          WithAttr(std::move(func), tvm::tir::attr::kIsGlobalFunc, Bool(true));
    }
    if (is_kernel_launch) {
      const auto &info = device_info_map_.at(gvar.get());
      // Kernel launches provide an int32 error code to the caller,
      // but do not accept any return type from the callee.
      {
        auto write_ptr = func.CopyOnWrite();
        write_ptr->ret_type = VoidType();
        write_ptr->body = ReturnRemover::Apply(write_ptr->body);
      }
      func =
          WithAttrs(std::move(func),
                    {{tvm::attr::kCallingConv,
                      Integer(tvm::CallingConv::kDeviceKernelLaunch)},
                     {tvm::tir::attr::kKernelLaunchParams, info.launch_params},
                     {tvm::attr::kGlobalSymbol, info.global_symbol}});
    }
    // @lei: workaround as we may require c host codegen, so we need to set the
    // global symbol for cpu backend.
    func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
    // Always expose the collected thread extents (and dynamic shared
    // memory size, when present) as function attributes.
    const auto &info = device_info_map_.at(gvar.get());
    const auto &thread_extent = info.thread_extent;
    func = WithAttr(std::move(func), "thread_extent", thread_extent);
    if (info.dyn_shmem_size.defined()) {
      func = WithAttr(std::move(func), "dyn_shared_memory_buf",
                      info.dyn_shmem_size.value());
    }
    return func;
  }

private:
  // Rewrite a call to another PrimFunc: same target -> leave as a
  // subroutine call; same device type but different target ->
  // call_extern; different device type -> packed-call kernel launch.
  PrimExpr VisitExpr_(const CallNode *op) override {
    auto node = Downcast<Call>(Parent::VisitExpr_(op));
    auto *gvar = op->op.as<GlobalVarNode>();
    if (!gvar)
      return std::move(node);
    auto it = device_info_map_.find(gvar);
    ICHECK(it != device_info_map_.end())
        << "CallNode attempted subroutine call to " << gvar->name_hint
        << ", but " << gvar->name_hint << " did not appear within the IRModule";
    const KernelInfo &dev_info = it->second;
    auto caller_target = current_target_.value();
    auto callee_target = dev_info.target;
    bool same_target = caller_target->str() == callee_target->str();
    if (same_target) {
      // Calls within the same target may be handled at codegen time
      // as internal subroutine calls.
      return std::move(node);
    }
    bool same_device_type = caller_target->GetTargetDeviceType() ==
                            callee_target->GetTargetDeviceType();
    if (same_device_type) {
      // Calls to another target using the same device (e.g. LLVM
      // calling a custom TIRToRuntime target) do not require a kernel
      // launch, but need to be replaced with call_extern.
      extern_function_call_.insert(gvar);
      Array<PrimExpr> args;
      args.push_back(StringImm(gvar->name_hint));
      for (const auto &arg : node->args) {
        args.push_back(arg);
      }
      return Call(node->dtype, builtin::call_extern(), args);
    }
    ICHECK(dev_info.launch_params.defined())
        << "CallNode attempted kernel launch to " << gvar->name_hint
        << " on target " << dev_info.target << ", but subroutine "
        << gvar->name_hint
        << " did not have the tir::attr::kKernelLaunchParams attribute "
        << "required for cross-target kernel launch";
    // Collected kernel information may be in terms of the callee's
    // arguments, but we need expressions for them in terms of the
    // caller's parameters. The param_map allows substitution of
    // parameter values into the thread extents, to generate
    // expressions that are valid within the caller.
    Map<Var, PrimExpr> param_map = [&]() {
      Map<Var, PrimExpr> param_map;
      CHECK_EQ(node->args.size(), dev_info.params.size())
          << "Function " << gvar->name_hint << " accepts "
          << dev_info.params.size()
          << " arguments as input, but is called using " << node->args.size()
          << " arguments";
      for (size_t i = 0; i < node->args.size(); i++) {
        param_map.Set(dev_info.params[i], node->args[i]);
      }
      return param_map;
    }();
    device_kernel_launch_.insert(gvar);
    // Packed-call arguments: symbol, original args, then the launch
    // arguments rewritten in terms of the caller's parameters.
    Array<PrimExpr> call_args;
    call_args.push_back(StringImm(dev_info.global_symbol));
    for (PrimExpr arg : node->args) {
      call_args.push_back(arg);
    }
    for (const auto &launch_arg : dev_info.launch_args) {
      call_args.push_back(Substitute(launch_arg, param_map));
    }
    // Kernel launches yield an int32 status code even for void callees.
    auto dtype = node->dtype.is_void() ? DataType::Int(32) : node->dtype;
    return Call(dtype, builtin::tvm_call_packed(), call_args);
  }

  // Target of the function currently being rewritten (caller side).
  Optional<Target> current_target_;
  // Pre-collected launch info for every PrimFunc in the module.
  std::unordered_map<const GlobalVarNode *, KernelInfo> device_info_map_;
  // Callees that were lowered to device kernel launches.
  std::unordered_set<const GlobalVarNode *> device_kernel_launch_;
  // Callees that were lowered to call_extern.
  std::unordered_set<const GlobalVarNode *> extern_function_call_;
};
namespace transform {
/*!
 * \brief Create the module pass that lowers cross-target kernel calls.
 *
 * The pass first collects per-function kernel information for every
 * PrimFunc in the module, then runs two rewrite rounds: one that rewrites
 * kernel-launch call sites, and one that attaches the kernel-launch
 * attributes to the callee functions.
 */
tvm::transform::Pass LowerDeviceKernelLaunch() {
  auto pass_func = [](IRModule mod,
                      tir::transform::PassContext ctx) -> IRModule {
    // Gather kernel info for every PrimFunc up front so call sites can be
    // rewritten regardless of function ordering in the module.
    std::unordered_map<const GlobalVarNode *, KernelInfo> device_info_map;
    for (const auto &[gvar, base_func] : mod->functions) {
      if (auto prim_func = base_func.as<PrimFunc>()) {
        device_info_map[gvar.get()] =
            DeviceInfoCollector::Collect(gvar, prim_func.value());
      }
    }
    DeviceKernelMutator mutator(std::move(device_info_map));

    // Apply a mutator member function to every PrimFunc and fold any
    // changed functions back into the module.
    auto apply_rewrite = [&mod, &mutator](auto member_fn) {
      IRModule updates;
      for (const auto &[gvar, base_func] : mod->functions) {
        if (auto *ptr = base_func.as<PrimFuncNode>()) {
          auto rewritten = (mutator.*member_fn)(gvar, GetRef<PrimFunc>(ptr));
          if (!rewritten.same_as(base_func)) {
            updates->Add(gvar, rewritten);
          }
        }
      }
      if (updates->functions.size()) {
        mod.CopyOnWrite()->Update(updates);
      }
    };

    // Round 1: rewrite launch sites; Round 2: update kernel attributes.
    apply_rewrite(&DeviceKernelMutator::RewriteKernelLaunchSite);
    apply_rewrite(&DeviceKernelMutator::UpdateKernelAttributes);
    return mod;
  };
  return tvm::transform::CreateModulePass(pass_func, 0,
                                          "tl.LowerDeviceKernelLaunch", {});
}
// Register the pass with the global FFI registry so it is reachable as
// "tl.transform.LowerDeviceKernelLaunch" (e.g. from the Python bindings).
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LowerDeviceKernelLaunch",
                        LowerDeviceKernelLaunch);
});
} // namespace transform
} // namespace tl
} // namespace tvm
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
* \brief Lower the special device storage access. * \brief Lower the special device storage access.
*/ */
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h> #include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target_info.h> #include <tvm/target/target_info.h>
#include <tvm/tir/buffer.h> #include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
...@@ -141,8 +142,11 @@ Pass LowerDeviceStorageAccessInfo() { ...@@ -141,8 +142,11 @@ Pass LowerDeviceStorageAccessInfo() {
{}); {});
} }
TVM_REGISTER_GLOBAL("tl.transform.LowerDeviceStorageAccessInfo") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(LowerDeviceStorageAccessInfo); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerDeviceStorageAccessInfo",
LowerDeviceStorageAccessInfo);
});
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \brief Lower Hopper intrinsics cuda GPU(sm90+) * \brief Lower Hopper intrinsics cuda GPU(sm90+)
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
...@@ -149,8 +150,10 @@ tvm::transform::Pass LowerHopperIntrin() { ...@@ -149,8 +150,10 @@ tvm::transform::Pass LowerHopperIntrin() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.LowerHopperIntrin") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(LowerHopperIntrin); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerHopperIntrin", LowerHopperIntrin);
});
#endif // (CUDA_MAJOR_VERSION >= 12) #endif // (CUDA_MAJOR_VERSION >= 12)
} // namespace tl } // namespace tl
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \brief Lower L2 persistent annotation * \brief Lower L2 persistent annotation
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
...@@ -98,8 +99,10 @@ tvm::transform::Pass LowerL2Persistent() { ...@@ -98,8 +99,10 @@ tvm::transform::Pass LowerL2Persistent() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.LowerL2Persistent") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(LowerL2Persistent); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerL2Persistent", LowerL2Persistent);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_opaque_block.cc
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
using namespace tir::attr;
/*!
* \brief Remove Block to ensure that the TIR can not be scheduled again.
*/
/*!
 * \brief Remove Block to ensure that the TIR can not be scheduled again.
 *
 * Lowers block constructs into plain imperative TIR:
 *  - BlockRealize predicates become IfThenElse guards.
 *  - Block allocations become DeclBuffer + Allocate nests, carrying any
 *    storage-align annotations collected up front.
 *  - Thread-bound loops become AttrStmt thread launches.
 *  - Unit loops (extent one, no annotations) are removed; references to
 *    their loop var are substituted by the loop's min value.
 *  - `pragma_*` annotations are re-emitted as nested AttrStmts, following
 *    the legacy TE schedule convention.
 */
class OpaqueBlockLower : public StmtExprMutator {
public:
  /*!
   * \brief Lower all opaque blocks inside \p body.
   * \param body The statement to lower.
   * \return The lowered statement.
   */
  static Stmt Rewrite(Stmt body) {
    OpaqueBlockLower lower;
    // Storage-align info is gathered before mutation so it can be attached
    // to the Allocate nodes created below.
    lower.storage_align_ = CollectStorageAlignAnnotation(body);
    return lower(std::move(body));
  }

private:
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    // We have convert blocks into opaque blocks in previous passes.
    // NOTE: the message previously said "FlattenBuffer" (copy-paste from the
    // upstream pass this was derived from); corrected to name this pass.
    ICHECK(op->iter_values.empty())
        << "Non-opaque blocks are not allowed in LowerOpaqueBlock. Please "
           "call pass ConvertBlocksToOpaque before.";
    // Step 1. Visit the body
    Block new_block = Downcast<Block>(this->VisitStmt(op->block));
    PrimExpr predicate = this->VisitExpr(op->predicate);
    // Step 2. Transform the `predicate` to if-then-else
    Stmt body = new_block->body;
    if (!is_one(predicate)) {
      body = IfThenElse(predicate, std::move(body));
    }
    // Step 3. Handle allocations in reverse order, so that
    // alloc_buffers[0] ends up as the outermost Allocate.
    for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
      const Buffer &buffer = new_block->alloc_buffers[i - 1];
      Array<PrimExpr> allocation_shape = GetBufferAllocationShape(buffer);
      body = DeclBuffer(buffer, std::move(body));
      // Forward any storage-align annotation onto the Allocate. The first
      // tuple element is reset to -1 — presumably the block-local buffer
      // index, which is meaningless once the block is gone (TODO confirm
      // against StorageAlignAnnotation's definition).
      Map<String, ffi::Any> allocate_annotations;
      auto it = storage_align_.find(buffer->data);
      if (it != storage_align_.end()) {
        StorageAlignAnnotation allocate_aligns;
        for (auto tuple : it->second) { // copy on purpose: mutated below
          tuple.Set<0>(-1);
          allocate_aligns.push_back(tuple);
        }
        allocate_annotations.Set(tir::attr::buffer_dim_align, allocate_aligns);
      }
      body = Allocate(buffer->data, buffer->dtype, allocation_shape,
                      const_true(), std::move(body), allocate_annotations);
    }
    // Step 4. Handle annotations, block annotations are not preserved by
    // default.
    std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
    HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true);
    // Wrap in reverse so the attr with the smallest key ends up outermost.
    for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
      body = AttrStmt(Integer(0), it->first, it->second, std::move(body));
    }
    return body;
  }

  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    // Blocks used purely for statement grouping are unwrapped entirely.
    if (block->annotations.count("stmt_group")) {
      return block->body;
    }
    return block;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    // Step 1. Update unit loop info.
    PrimExpr min = this->VisitExpr(op->min);
    PrimExpr extent = this->VisitExpr(op->extent);
    if (is_one(extent) && op->annotations.empty()) {
      // handling unit loop: record the substitution before visiting the
      // body, where uses of the loop var will be replaced by `min`.
      unit_loop_vars_[op->loop_var] = min;
    }
    // Step 2. Visit recursively
    Stmt body = this->VisitStmt(op->body);
    // Step 3. Handle annotations
    std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
    Map<String, ffi::Any> new_annotations =
        HandleAnnotations(op->annotations, &pragma_attrs, /*is_block=*/false);
    // Step 4. Create new For loop accordingly
    if (op->kind == ForKind::kThreadBinding) {
      // Case 1. Thread binding
      ICHECK(op->thread_binding.defined());
      String thread_tag = op->thread_binding.value()->thread_tag;
      body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
    } else if (is_one(extent) && op->annotations.empty()) {
      // Case 2. Unit loop: the loop vanishes; its var was substituted above.
      return body;
    } else {
      // Case 3. An ordinary loop
      body = For(op->loop_var, std::move(min), std::move(extent), op->kind,
                 std::move(body), std::nullopt, new_annotations);
    }
    // Step 5. Insert nested attrs (reverse order: smallest key outermost).
    for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
      body = AttrStmt(op->loop_var, it->first, it->second, std::move(body));
    }
    return body;
  }

  PrimExpr VisitExpr_(const VarNode *op) final {
    Var var = GetRef<Var>(op);
    auto it = unit_loop_vars_.find(var);
    if (it == unit_loop_vars_.end()) {
      return var;
    } else {
      // Substitute a unit-loop var with its min value, casting when the
      // dtypes of the var and the min expression differ.
      PrimExpr expr = it->second;
      if (expr.dtype() != var.dtype()) {
        expr = tvm::cast(var.dtype(), std::move(expr));
      }
      return expr;
    }
  }

  /*!
   * \brief Build the AttrStmt that launches a thread-bound loop.
   *
   * vthread-tagged loops become tir::attr::virtual_thread attrs; all other
   * tags become tir::attr::thread_extent attrs.
   */
  static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var,
                               String thread_tag, Stmt body) {
    IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent),
                     /*var=*/std::move(var),
                     /*iter_type=*/IterVarType::kThreadIndex,
                     /*thread_tag=*/thread_tag);
    String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" ||
                       thread_tag == "vthread.y" || thread_tag == "vthread.z")
                          ? tir::attr::virtual_thread
                          : tir::attr::thread_extent;
    return AttrStmt(/*node=*/std::move(iter_var),
                    /*attr_key=*/std::move(attr_key),
                    /*value=*/std::move(extent),
                    /*body=*/std::move(body));
  }

  /*! \brief Convert attr value from annotation map into PrimExpr. */
  PrimExpr ConvertAttrValue(const String &key, const Any &obj) {
    if (obj == nullptr) {
      return PrimExpr();
    } else if (auto expr = obj.try_cast<PrimExpr>()) {
      return expr.value();
    } else if (auto str = obj.try_cast<String>()) {
      // Return the prvalue directly: wrapping it in std::move (as before)
      // is a pessimizing move that defeats copy elision.
      return StringImm(str.value());
    } else {
      LOG(FATAL) << "Illegal attribute of key " << key << ", value type "
                 << obj.GetTypeKey() << " not supported";
      return PrimExpr();
    }
  }

  /*!
   * \brief Helper to handle annotation dict.
   * (1) if the attr key is prefixed by `pragma_`, move to ordered kv list. They
   * are lowered to `AttrStmt` by legacy TE schedule convention.
   * (2) the non-pragma loop annotations are preserved
   * (3) the non-pragma block annotations are dropped
   * \return New annotation dict with preserved keys. Also update pragma attr
   * pairs ordered by key.
   */
  Map<String, ffi::Any>
  HandleAnnotations(const Map<String, ffi::Any> &annotations,
                    std::vector<std::pair<std::string, PrimExpr>> *pragma_attrs,
                    bool is_block) {
    Map<String, ffi::Any> preserved_annotations;
    pragma_attrs->clear();
    for (const auto &kv : annotations) {
      const String &key = kv.first;
      if (tir::attr::IsPragmaKey(key)) {
        pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second));
      } else if (!is_block) {
        // the loop annotation is preserved
        preserved_annotations.Set(key, kv.second);
      }
    }
    // Sort by key for a deterministic AttrStmt emission order.
    std::sort(
        pragma_attrs->begin(), pragma_attrs->end(),
        [](const auto &p1, const auto &p2) { return p1.first < p2.first; });
    return preserved_annotations;
  }

  /*! \brief Record the loop_var and loop start value of unit loops, whose
   * extent is one. */
  std::unordered_map<Var, PrimExpr> unit_loop_vars_;
  /*! \brief The map from buffer var to its storage alignment information. */
  std::unordered_map<Var, StorageAlignAnnotation> storage_align_;
};
/*!
 * \brief Lower all opaque blocks in the body of \p f.
 * \param f The function to lower; taken by value and returned with a new body.
 * \return The lowered function.
 */
PrimFunc TLLowerOpaqueBlock(PrimFunc f) {
  PrimFuncNode *mutable_func = f.CopyOnWrite();
  mutable_func->body = OpaqueBlockLower::Rewrite(std::move(mutable_func->body));
  return f;
}
/*!
 * \brief Create the tl.LowerOpaqueBlock PrimFunc pass.
 * \return A pass that applies TLLowerOpaqueBlock to every PrimFunc.
 */
tir::transform::Pass LowerOpaqueBlock() {
  using namespace tir::transform;
  // The module and context parameters are unused; only the function body
  // is rewritten.
  auto pass_func = [](PrimFunc func, IRModule, PassContext) {
    return TLLowerOpaqueBlock(std::move(func));
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {});
}
// Register the pass with the global FFI registry so it is reachable as
// "tl.transform.LowerOpaqueBlock" (e.g. from the Python bindings).
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LowerOpaqueBlock", LowerOpaqueBlock);
});
} // namespace tl
} // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment