swizzle.h 1.83 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
/*!
 * \file swizzle.h
 * \brief Define swizzled layout
 *
 */

#ifndef TVM_TL_LAYOUT_SWIZZLE_H_
#define TVM_TL_LAYOUT_SWIZZLE_H_

#include "layout.h"

namespace tvm {
namespace tl {

/*!
 * \brief Swizzle pattern described by three integer parameters.
 *
 * Stores the (bits, base, shift) triple given at construction and applies it
 * to an index expression through swizzle(). The transformation itself is
 * defined out of line; this header only declares it.
 */
class SwizzlePattern {
public:
  /*!
   * \brief Construct a pattern whose three parameters are all zero.
   *
   * The default member initializers below guarantee a well-defined state, so
   * the accessors and operator== are safe to use on a default-constructed
   * pattern (e.g. the one held by a default-constructed SwizzledLayoutNode).
   */
  SwizzlePattern() = default;
  /*!
   * \brief Construct a pattern from explicit parameters.
   * \param bits Value later reported by Bits().
   * \param base Value later reported by Base().
   * \param shift Value later reported by Shift().
   */
  SwizzlePattern(int bits, int base, int shift);
  /*!
   * \brief Apply this pattern to an index expression.
   * \param expr The expression to transform.
   * \return The swizzled expression.
   */
  PrimExpr swizzle(PrimExpr expr) const;
  /*! \return The `bits` parameter of this pattern. */
  int Bits() const { return bits_; }
  /*! \return The `base` parameter of this pattern. */
  int Base() const { return base_; }
  /*! \return The `shift` parameter of this pattern. */
  int Shift() const { return shift_; }
  /*!
   * \brief Equality comparison, defined out of line.
   * NOTE(review): presumably compares all three parameters; confirm against
   * the definition.
   */
  bool operator==(const SwizzlePattern &other) const;

private:
  // Zero-initialized so a defaulted constructor never leaves indeterminate
  // values behind (reading those would be undefined behavior).
  int bits_{0};
  int base_{0};
  int shift_{0};
};

/*!
 * \brief Layout node that carries a SwizzlePattern alongside the LayoutNode
 *        state.
 *
 * Overrides Forward(), Inverse() and DebugOutput() of LayoutNode and adds
 * equality helpers that take another SwizzledLayoutNode.
 */
class SwizzledLayoutNode : public LayoutNode {
public:
  // -- Construction ----------------------------------------------------------
  SwizzledLayoutNode() = default;
  /*!
   * \brief Construct a swizzled layout node.
   * \param input_size Input size of the layout.
   * \param forward_index Forward index expressions of the layout.
   * \param pattern Swizzle pattern stored in this node.
   */
  SwizzledLayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
                     SwizzlePattern pattern);

  // -- LayoutNode overrides (sealed with `final`) ----------------------------
  Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const final;
  Layout Inverse() const final;
  std::string DebugOutput() const final;

  // -- Comparison ------------------------------------------------------------
  bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const;
  bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;

  // -- Reflection and runtime type information -------------------------------
  static void RegisterReflection();
  static constexpr const char *_type_key = "tl.SwizzledLayout";
  TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);

private:
  SwizzlePattern pattern_;
};

/*!
 * \brief Managed object reference to a SwizzledLayoutNode.
 *
 * Both constructors take the forward index expressions and the swizzle
 * pattern. They differ in whether the input space is given as sizes or as
 * forward iteration variables.
 */
class SwizzledLayout : public Layout {
public:
  /*!
   * \brief Build a swizzled layout from input sizes.
   * \param input_size Input size of the layout.
   * \param forward_index Forward index expressions.
   * \param pattern Swizzle pattern to attach.
   */
  TVM_DLL SwizzledLayout(Array<PrimExpr> input_size,
                         Array<PrimExpr> forward_index, SwizzlePattern pattern);
  /*!
   * \brief Build a swizzled layout from forward iteration variables.
   * \param forward_var Forward iteration variables.
   * \param forward_index Forward index expressions.
   * \param pattern Swizzle pattern to attach.
   */
  TVM_DLL SwizzledLayout(Array<IterVar> forward_var,
                         Array<PrimExpr> forward_index, SwizzlePattern pattern);

  TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode);
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_LAYOUT_SWIZZLE_H_