Unverified Commit e7e4e65b authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Enhancement] Add debug output methods for Layout and Fragment classes (#1392)

parent 242b43bb
......@@ -12,6 +12,8 @@
#include <tvm/tir/stmt_functor.h>
#include "arith/pattern_match.h"
#include "tvm/node/functor.h"
#include "tvm/node/repr_printer.h"
#include "utils.h"
namespace tvm {
......@@ -78,7 +80,8 @@ void LayoutNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<LayoutNode>()
.def_ro("input_size", &LayoutNode::input_size_)
.def_ro("forward_index", &LayoutNode::forward_index_);
.def_ro("forward_index", &LayoutNode::forward_index_)
.def("_DebugOutput", &LayoutNode::DebugOutput);
}
void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
......@@ -716,8 +719,19 @@ void FragmentNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<FragmentNode>()
.def_ro("forward_thread", &FragmentNode::forward_thread_)
.def_ro("replicate_size", &FragmentNode::replicate_size_);
}
.def_ro("replicate_size", &FragmentNode::replicate_size_)
.def("_DebugOutput", &FragmentNode::DebugOutput);
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FragmentNode>([](const ObjectRef &obj, ReprPrinter *p) {
auto *node = static_cast<const FragmentNode *>(obj.get());
p->stream << node->DebugOutput();
})
.set_dispatch<LayoutNode>([](const ObjectRef &obj, ReprPrinter *p) {
auto *node = static_cast<const LayoutNode *>(obj.get());
p->stream << node->DebugOutput();
});
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
......
......@@ -203,7 +203,8 @@ class Fragment(Layout):
str
A string showing the thread dimension and the index dimension.
"""
return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>"
return self._DebugOutput()
# return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>"
def is_equal(self, other: 'Fragment') -> bool:
"""
......
......@@ -143,4 +143,5 @@ class Layout(Node):
return _ffi_api.Layout_is_equal(self, other)
def __repr__(self):
return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {self.get_forward_vars()} -> {self.get_forward_index()}>"
return self._DebugOutput()
# return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {self.get_forward_vars()} -> {self.get_forward_index()}>"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment