"docs/source/api/vscode:/vscode.git/clone" did not exist on "cf5265ad413e200c697ee5e13bd326fc82f11c6c"
Unverified Commit e7e4e65b authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Enhancement] Add debug output methods for Layout and Fragment classes (#1392)

parent 242b43bb
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include "arith/pattern_match.h" #include "arith/pattern_match.h"
#include "tvm/node/functor.h"
#include "tvm/node/repr_printer.h"
#include "utils.h" #include "utils.h"
namespace tvm { namespace tvm {
...@@ -78,7 +80,8 @@ void LayoutNode::RegisterReflection() { ...@@ -78,7 +80,8 @@ void LayoutNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::ObjectDef<LayoutNode>() refl::ObjectDef<LayoutNode>()
.def_ro("input_size", &LayoutNode::input_size_) .def_ro("input_size", &LayoutNode::input_size_)
.def_ro("forward_index", &LayoutNode::forward_index_); .def_ro("forward_index", &LayoutNode::forward_index_)
.def("_DebugOutput", &LayoutNode::DebugOutput);
} }
void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const { void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
...@@ -716,8 +719,19 @@ void FragmentNode::RegisterReflection() { ...@@ -716,8 +719,19 @@ void FragmentNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::ObjectDef<FragmentNode>() refl::ObjectDef<FragmentNode>()
.def_ro("forward_thread", &FragmentNode::forward_thread_) .def_ro("forward_thread", &FragmentNode::forward_thread_)
.def_ro("replicate_size", &FragmentNode::replicate_size_); .def_ro("replicate_size", &FragmentNode::replicate_size_)
} .def("_DebugOutput", &FragmentNode::DebugOutput);
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FragmentNode>([](const ObjectRef &obj, ReprPrinter *p) {
auto *node = static_cast<const FragmentNode *>(obj.get());
p->stream << node->DebugOutput();
})
.set_dispatch<LayoutNode>([](const ObjectRef &obj, ReprPrinter *p) {
auto *node = static_cast<const LayoutNode *>(obj.get());
p->stream << node->DebugOutput();
});
TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
......
...@@ -203,7 +203,8 @@ class Fragment(Layout): ...@@ -203,7 +203,8 @@ class Fragment(Layout):
str str
A string showing the thread dimension and the index dimension. A string showing the thread dimension and the index dimension.
""" """
return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>" return self._DebugOutput()
# return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>"
def is_equal(self, other: 'Fragment') -> bool: def is_equal(self, other: 'Fragment') -> bool:
""" """
......
...@@ -143,4 +143,5 @@ class Layout(Node): ...@@ -143,4 +143,5 @@ class Layout(Node):
return _ffi_api.Layout_is_equal(self, other) return _ffi_api.Layout_is_equal(self, other)
def __repr__(self): def __repr__(self):
return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {self.get_forward_vars()} -> {self.get_forward_index()}>" return self._DebugOutput()
# return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {self.get_forward_vars()} -> {self.get_forward_index()}>"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment