// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file swizzle.h
 * \brief Define swizzled layout
 *
 */

#ifndef TVM_TL_LAYOUT_SWIZZLE_H_
#define TVM_TL_LAYOUT_SWIZZLE_H_

#include "layout.h"

namespace tvm {
namespace tl {

/*!
 * \brief Swizzle pattern
 *
 * Stores the three integer parameters of a swizzle (\p bits, \p base and
 * \p shift) and applies the pattern to an index expression via swizzle().
 *
 * NOTE(review): the precise meaning of the three parameters is defined by
 * the implementation of swizzle(), which is not visible in this header; it
 * presumably mirrors CuTe's Swizzle<B, M, S> convention -- confirm against
 * swizzle.cc.
 */
class SwizzlePattern {
public:
  /*! \brief Construct an all-zero pattern (bits = base = shift = 0). */
  SwizzlePattern() = default;
  /*!
   * \brief Construct a pattern from its parameters.
   * \param bits  The swizzle "bits" parameter.
   * \param base  The swizzle "base" parameter.
   * \param shift The swizzle "shift" parameter.
   */
  SwizzlePattern(int bits, int base, int shift);
  /*!
   * \brief Apply this pattern to an index expression.
   * \param expr The expression to swizzle.
   * \return The swizzled expression.
   */
  PrimExpr swizzle(PrimExpr expr) const;
  /*! \return The "bits" parameter. */
  int Bits() const { return bits_; }
  /*! \return The "base" parameter. */
  int Base() const { return base_; }
  /*! \return The "shift" parameter. */
  int Shift() const { return shift_; }
  /*! \brief Compare two patterns; defined out of line. */
  bool operator==(const SwizzlePattern &other) const;

private:
  // In-class initializers make the defaulted constructor yield a
  // well-defined (all-zero) pattern instead of indeterminate values, so
  // Bits()/Base()/Shift()/operator== are safe on a default-constructed
  // instance (e.g. a default-constructed SwizzledLayoutNode::pattern_).
  int bits_{0};
  int base_{0};
  int shift_{0};
};

/*!
 * \brief Layout with swizzle
 */
class SwizzledLayoutNode : public LayoutNode {
41
public:
42
43
44
45
  SwizzledLayoutNode() = default;
  SwizzledLayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
                     SwizzlePattern pattern);

46
  Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const final;
47
  Layout Inverse() const final;
48
49
  std::string DebugOutput() const final;
  bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const;
50
51
52
  static constexpr const char *_type_key = "tl.SwizzledLayout";
  bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
  void VisitAttrs(tvm::AttrVisitor *v);
53
54
  TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);

55
private:
56
57
58
59
60
61
62
  SwizzlePattern pattern_;
};

/*!
 * \brief SwizzledLayout reference class.
 */
class SwizzledLayout : public Layout {
63
64
65
66
67
public:
  TVM_DLL SwizzledLayout(Array<IterVar> forward_var,
                         Array<PrimExpr> forward_index, SwizzlePattern pattern);
  TVM_DLL SwizzledLayout(Array<PrimExpr> input_size,
                         Array<PrimExpr> forward_index, SwizzlePattern pattern);
68
69
70
71

  TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode);
};

72
73
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LAYOUT_SWIZZLE_H_