Commit fa511857 authored by Lei Wang, committed by GitHub

[Lint] Overall Typo and Linting Fixes (#13)

* README.md fixed

* update test ci

* Lint and Typo Fix

* Clang Format Lint Fix
parent be55163f
@@ -13,78 +13,94 @@ using cutlass::gemm::GemmShape;
// Add 128 bits of padding when the last dim is a multiple of 256 bits
template <typename T, bool transpose, int M, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutA {
using Layout = typename std::conditional<transpose, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type;
using Layout =
typename std::conditional<transpose, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type;
static int constexpr Dim = transpose ? M : K;
static int constexpr Stride = (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
static int constexpr Stride =
(Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};
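// Worked example of the padding rule above (assuming T = half_t and Dim = 64):
// Dim * sizeof(T) = 128 bytes, a multiple of 32 bytes (256 bits), so
// Stride = 64 + 16 / 2 = 72 elements, i.e. 128 bits (16 bytes) of padding per row,
// the usual trick for avoiding shared-memory bank conflicts.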
template <typename T, bool transpose, int N, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutB {
using Layout = typename std::conditional<transpose, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type;
using Layout =
typename std::conditional<transpose, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type;
static int constexpr Dim = transpose ? K : N;
static int constexpr Stride = (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
static int constexpr Stride =
(Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};
// Partial specialization for half_t
template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, true, M, K, typename std::enable_if<M % 64 == 0>::type> {
using Layout = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous<16>;
struct DispatchSharedMemoryLayoutA<half_t, true, M, K,
typename std::enable_if<M % 64 == 0>::type> {
using Layout =
cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous<16>;
static int constexpr Stride = M;
};
template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, false, M, K> {
using Layout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
using Layout =
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
static int constexpr Stride = M;
};
template <int N, int K>
struct DispatchSharedMemoryLayoutB<half_t, true, N, K> {
using Layout = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
template <int N, int K> struct DispatchSharedMemoryLayoutB<half_t, true, N, K> {
using Layout =
cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
static int constexpr Stride = N;
};
template <int N, int K>
struct DispatchSharedMemoryLayoutB<half_t, false, N, K,
typename std::enable_if<N % 64 == 0>::type> {
using Layout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous<16>;
using Layout =
cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous<16>;
static int constexpr Stride = N;
};
template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type_raw, typename B_type_raw, typename C_type_raw>
template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
class GemmTensorOp {
public:
public:
using A_type = A_type_raw;
using B_type = B_type_raw;
using C_type = C_type_raw;
using InstructionShape = GemmShape<16, 16, 4>;
using SMemLayoutA =
typename DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM, Shape::kK>::Layout;
typename DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
Shape::kK>::Layout;
using SMemLayoutB =
typename DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN, Shape::kK>::Layout;
typename DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
Shape::kK>::Layout;
static constexpr int stride_A =
DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM, Shape::kK>::Stride;
DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
Shape::kK>::Stride;
static constexpr int stride_B =
DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN, Shape::kK>::Stride;
DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
Shape::kK>::Stride;
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<InstructionShape, 32, A_type,
typename std::conditional<trans_A, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type,
B_type,
typename std::conditional<trans_B, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type,
C_type, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>,
cutlass::MatrixShape<1, 1> >;
cutlass::arch::Mma<
InstructionShape, 32, A_type,
typename std::conditional<trans_A, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type,
B_type,
typename std::conditional<trans_B, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type,
C_type, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>,
cutlass::MatrixShape<1, 1>>;
static_assert(Shape::kM % num_warp_m == 0);
static_assert(Shape::kN % num_warp_n == 0);
using MmaWarp = typename cutlass::gemm::warp::MmaVoltaTensorOp<
GemmShape<Shape::kM / num_warp_m, Shape::kN / num_warp_n, InstructionShape::kK>, A_type,
SMemLayoutA, B_type, SMemLayoutB, C_type, cutlass::layout::RowMajor, Policy>;
GemmShape<Shape::kM / num_warp_m, Shape::kN / num_warp_n,
InstructionShape::kK>,
A_type, SMemLayoutA, B_type, SMemLayoutB, C_type,
cutlass::layout::RowMajor, Policy>;
using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
@@ -97,13 +113,14 @@ class GemmTensorOp {
static_assert(Shape::kK % InstructionShape::kK == 0);
static int constexpr kKgroups = Shape::kK / InstructionShape::kK;
static CUTLASS_DEVICE void body(A_type_raw* pA, B_type_raw* pB, FragmentC& accum,
const int warp_idx_m, const int warp_idx_n, const int lane_id) {
static CUTLASS_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
FragmentC &accum, const int warp_idx_m,
const int warp_idx_n, const int lane_id) {
MmaWarp mma_op;
FragmentA frag_A;
FragmentB frag_B;
const TensorRefA ref_A((A_type*)pA, stride_A);
const TensorRefB ref_B((B_type*)pB, stride_B);
const TensorRefA ref_A((A_type *)pA, stride_A);
const TensorRefB ref_B((B_type *)pB, stride_B);
IteratorA iter_A(ref_A, lane_id);
IteratorB iter_B(ref_B, lane_id);
iter_A.add_tile_offset({warp_idx_m, 0});
@@ -118,11 +135,12 @@ class GemmTensorOp {
}
}
static CUTLASS_DEVICE void body_rs(const FragmentA* frag_A, B_type_raw* pB, FragmentC& accum,
const int warp_idx_n, const int lane_id) {
static CUTLASS_DEVICE void body_rs(const FragmentA *frag_A, B_type_raw *pB,
FragmentC &accum, const int warp_idx_n,
const int lane_id) {
MmaWarp mma_op;
FragmentB frag_B;
const TensorRefB ref_B((B_type*)pB, stride_B);
const TensorRefB ref_B((B_type *)pB, stride_B);
IteratorB iter_B(ref_B, lane_id);
iter_B.add_tile_offset({0, warp_idx_n});
CUTLASS_PRAGMA_UNROLL
@@ -136,27 +154,29 @@ class GemmTensorOp {
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type* pA, B_type* pB, C_type* accum) {
using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A, trans_B, A_type,
B_type, C_type>;
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
using FragmentC = typename MMA::FragmentC;
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
MMA::body(pA, pB, *(FragmentC*)(accum), warp_id / num_warp_n, warp_id % num_warp_n, lane_id);
MMA::body(pA, pB, *(FragmentC *)(accum), warp_id / num_warp_n,
warp_id % num_warp_n, lane_id);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type* pA, B_type* pB, C_type* accum) {
using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A, trans_B, A_type,
B_type, C_type>;
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
using FragmentA = typename MMA::FragmentA;
using FragmentC = typename MMA::FragmentC;
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
MMA::body_rs((const FragmentA*)(pA), pB, *(FragmentC*)(accum), warp_id % num_warp_n, lane_id);
MMA::body_rs((const FragmentA *)(pA), pB, *(FragmentC *)(accum),
warp_id % num_warp_n, lane_id);
}
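// Usage sketch (hypothetical tile/warp configuration; assumes the kernel runs with
// num_warp_m * num_warp_n * 32 threads per block, that pA/pB point at the shared
// memory tiles and accum at the per-thread accumulator fragment):
//   tl::gemm_ss<64, 64, 32, 2, 2, false, false, cutlass::half_t, cutlass::half_t,
//               float>(smem_A, smem_B, acc);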
}; // namespace tl
}; // namespace tl
@@ -12,39 +12,32 @@ template <typename A_type, typename B_type, typename C_type>
struct DispatchInstruction;
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <>
struct DispatchInstruction<half_t, half_t, half_t> {
template <> struct DispatchInstruction<half_t, half_t, half_t> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <>
struct DispatchInstruction<half_t, half_t, float> {
template <> struct DispatchInstruction<half_t, half_t, float> {
using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float> {
template <> struct DispatchInstruction<bfloat16_t, bfloat16_t, float> {
using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float> {
template <> struct DispatchInstruction<tfloat32_t, tfloat32_t, float> {
using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <>
struct DispatchInstruction<int8_t, int8_t, int> {
template <> struct DispatchInstruction<int8_t, int8_t, int> {
using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <>
struct DispatchInstruction<double, double, double> {
template <> struct DispatchInstruction<double, double, double> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
using MMA_Group = Layout<Shape<_2, _2, _1>>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <>
struct DispatchInstruction<half_t, half_t, float> {
template <> struct DispatchInstruction<half_t, half_t, float> {
using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using MMA_Group = Layout<Shape<_1, _2, _2>>;
};
@@ -54,149 +47,175 @@ template <int Bits, int N, int K, bool K_inner, typename Enable = void>
struct OperandTraits {
// Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N;
static constexpr int padded = stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
using Layout =
typename std::conditional<K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
static constexpr int padded =
stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
using Layout = typename std::conditional<
K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
using Copy = DefaultCopy;
};
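// Example of the padding arithmetic above (assuming int8 operands, i.e. Bits = 8,
// with K_inner = false and N = 64, a case apparently not covered by the swizzled
// specializations shown below): stride = 64, 256 / 8 = 32 and 64 % 32 == 0, so
// padded = 64 + 128 / 8 = 80 elements, i.e. 16 extra bytes per row.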
template <int N, int K>
struct OperandTraits<16, N, K, true, typename std::enable_if<K % 64 == 32>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
struct OperandTraits<16, N, K, true,
typename std::enable_if<K % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<16, N, K, true, typename std::enable_if<K % 64 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
struct OperandTraits<16, N, K, true,
typename std::enable_if<K % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<16, N, K, false, typename std::enable_if<N % 64 == 32>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}, Step<_2, _1>{}));
struct OperandTraits<16, N, K, false,
typename std::enable_if<N % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<16, N, K, false, typename std::enable_if<N % 64 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}, Step<_2, _1>{}));
struct OperandTraits<16, N, K, false,
typename std::enable_if<N % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<32, N, K, true, typename std::enable_if<K % 32 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
struct OperandTraits<32, N, K, true,
typename std::enable_if<K % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<32, N, K, true, typename std::enable_if<K % 32 == 16>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
struct OperandTraits<32, N, K, true,
typename std::enable_if<K % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<32, N, K, false, typename std::enable_if<N % 32 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}, Step<_2, _1>{}));
struct OperandTraits<32, N, K, false,
typename std::enable_if<N % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<32, N, K, false, typename std::enable_if<N % 32 == 16>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}, Step<_2, _1>{}));
struct OperandTraits<32, N, K, false,
typename std::enable_if<N % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<8, N, K, true, typename std::enable_if<K % 128 == 64>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
struct OperandTraits<8, N, K, true,
typename std::enable_if<K % 128 == 64>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<8, N, K, true, typename std::enable_if<K % 128 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
struct OperandTraits<8, N, K, true,
typename std::enable_if<K % 128 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<64, N, K, true, typename std::enable_if<K % 16 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
struct OperandTraits<64, N, K, true,
typename std::enable_if<K % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<64, N, K, false, typename std::enable_if<N % 16 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}, Step<_2, _1>{}));
struct OperandTraits<64, N, K, false,
typename std::enable_if<N % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = DefaultCopy;
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type_raw, typename B_type_raw, typename C_type_raw>
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
class GemmTensorOp {
public:
using A_type = typename std::conditional<std::is_same<A_type_raw, float>::value, tfloat32_t,
A_type_raw>::type;
using B_type = typename std::conditional<std::is_same<B_type_raw, float>::value, tfloat32_t,
B_type_raw>::type;
public:
using A_type =
typename std::conditional<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
using B_type =
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>::type;
using C_type = C_type_raw;
using Instruction = DispatchInstruction<A_type, B_type, C_type>;
using OperandATraits = OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
using OperandBTraits = OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout;
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;
using TileMma =
TiledMMA<typename Instruction::MMA, Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
typename Instruction::MMA_Group>;
using TileMma = TiledMMA<typename Instruction::MMA,
Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
typename Instruction::MMA_Group>;
template <class... Args>
static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const& layout) {
static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
return layout;
}
// In fp16, when layout is KxN and n_warp is 1 and N % 64 == 0
// the original layout fails to compile; currently using this as a workaround
template <class... Args>
static CUTE_DEVICE auto remove_swizzle(ComposedLayout<Args...> const& layout) {
static CUTE_DEVICE auto
remove_swizzle(ComposedLayout<Args...> const &layout) {
if constexpr (sizeof(A_type) == 2)
return layout.layout_b();
else
return layout;
}
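// In other words, for 2-byte element types the swizzle functor is stripped and only
// the plain offset/stride layout (layout_b()) is returned, which is what sidesteps
// the compile failure described above; other element sizes keep the composed
// (swizzled) layout unchanged.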
static CUTE_DEVICE void body(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) {
static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type*>(pA)), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type*>(pB)), SmemLayoutB{});
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
@@ -212,10 +231,12 @@ class GemmTensorOp {
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
// when layout is KxN and n_warp is 1, there seems to be a bug; use this as a workaround
// when layout is KxN and n_warp is 1, there seems to be a bug; use this as a
// workaround
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
CUTE_UNROLL
@@ -226,9 +247,11 @@ class GemmTensorOp {
}
}
static CUTE_DEVICE void body_rs(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) {
static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type*>(pB)), SmemLayoutB{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
@@ -239,10 +262,12 @@ class GemmTensorOp {
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
Tensor tCrA = make_tensor(make_rmem_ptr(reinterpret_cast<A_type*>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
Tensor tCrA =
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
@@ -255,9 +280,11 @@ class GemmTensorOp {
}
}
static CUTE_DEVICE void body_sr(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) {
static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type*>(pA)), SmemLayoutA{});
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
@@ -268,10 +295,12 @@ class GemmTensorOp {
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
Tensor tCrB = make_tensor(make_rmem_ptr(reinterpret_cast<B_type*>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
Tensor tCrB =
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
@@ -285,32 +314,32 @@ class GemmTensorOp {
}
};
} // namespace cute
} // namespace cute
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type* pA, B_type* pB, C_type* accum) {
using MMA =
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>;
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type* pA, B_type* pB, C_type* accum) {
using MMA =
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>;
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type* pA, B_type* pB, C_type* accum) {
using MMA =
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>;
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}
} // namespace tl
} // namespace tl
@@ -2,9 +2,9 @@
// Licensed under the MIT License.
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/arch/barrier.h>
#include <cute/algorithm/copy.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>
#include "common.h"
@@ -19,78 +19,112 @@ CUTE_HOST_DEVICE constexpr auto ss_smem_selector() {
static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8.");
if constexpr (major == GMMA::Major::MN) {
if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) == 0) {
if constexpr (BLK_MN0 %
size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_MN_SW128_Atom<ElementType>{};
} else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom<ElementType>{}) == 0) {
} else if constexpr (BLK_MN0 %
size<0>(
GMMA::Layout_MN_SW64_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_MN_SW64_Atom<ElementType>{};
} else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom<ElementType>{}) == 0) {
} else if constexpr (BLK_MN0 %
size<0>(
GMMA::Layout_MN_SW32_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_MN_SW32_Atom<ElementType>{};
} else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0) {
} else if constexpr (BLK_MN0 %
size<0>(
GMMA::Layout_MN_INTER_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_MN_INTER_Atom<ElementType>{};
} else {
static_assert(
BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0,
"BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})");
"BLK_MN0 must be a multiple of "
"size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})");
}
} else if constexpr (major == GMMA::Major::K) {
if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom<ElementType>{}) == 0) {
if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_K_SW128_Atom<ElementType>{};
} else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom<ElementType>{}) == 0) {
} else if constexpr (BLK_K0 %
size<1>(GMMA::Layout_K_SW64_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_K_SW64_Atom<ElementType>{};
} else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom<ElementType>{}) == 0) {
} else if constexpr (BLK_K0 %
size<1>(GMMA::Layout_K_SW32_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_K_SW32_Atom<ElementType>{};
} else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0) {
} else if constexpr (BLK_K0 %
size<1>(
GMMA::Layout_K_INTER_Atom<ElementType>{}) ==
0) {
return GMMA::Layout_K_INTER_Atom<ElementType>{};
} else {
static_assert(
BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0,
"BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{})");
"BLK_K0 must be a multiple of "
"size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{})");
}
}
}
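// Example (assuming the usual CuTe GMMA atom extents, e.g.
// Layout_K_SW128_Atom<half_t> spanning 64 half elements = 128 bytes along K): for
// ElementType = half_t and major == GMMA::Major::K, BLK_K0 = 64 selects the 128-byte
// swizzle atom, BLK_K0 = 32 falls through to the 64-byte atom, and BLK_K0 = 8 ends
// up with the interleaved (non-swizzled) atom.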
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type_raw, typename B_type_raw, typename C_type_raw>
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
class GemmTensorOp {
public:
using A_type = conditional_t<std::is_same<A_type_raw, float>::value, tfloat32_t, A_type_raw>;
using B_type = conditional_t<std::is_same<B_type_raw, float>::value, tfloat32_t, B_type_raw>;
public:
using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>;
using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>;
using C_type = C_type_raw;
static constexpr GMMA::Major GmmaMajorA = trans_A ? GMMA::Major::MN : GMMA::Major::K;
static constexpr GMMA::Major GmmaMajorB = trans_B ? GMMA::Major::K : GMMA::Major::MN;
static constexpr GMMA::Major GmmaMajorA =
trans_A ? GMMA::Major::MN : GMMA::Major::K;
static constexpr GMMA::Major GmmaMajorB =
trans_B ? GMMA::Major::K : GMMA::Major::MN;
using SmemLayoutAtomA = decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
using SmemLayoutAtomB = decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());
using SmemLayoutAtomA =
decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
using SmemLayoutAtomB =
decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));
// static_assert(num_warp_n == 1);
static_assert(num_warp_m % 4 == 0);
template <int wg_wait=0>
static CUTE_DEVICE void body(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) {
template <int wg_wait = 0>
static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type*>(pA)), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type*>(pB)), SmemLayoutB{});
auto tiled_mma =
make_tiled_mma(GMMA::ss_op_selector<A_type, B_type, C_type, Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
GmmaMajorA, GmmaMajorB>(),
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
auto tiled_mma = make_tiled_mma(
GMMA::ss_op_selector<A_type, B_type, C_type,
Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
GmmaMajorA, GmmaMajorB>(),
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
auto thr_mma = tiled_mma.get_thread_slice(tid);
// Allocate registers for pipelining
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
warpgroup_fence_operand(acc);
warpgroup_arrive();
@@ -103,7 +137,9 @@ class GemmTensorOp {
}
warpgroup_commit_batch();
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
if constexpr (wg_wait >= 0) {
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(acc);
// warpgroup_fence_operand(acc);
// warpgroup_arrive();
@@ -115,25 +151,31 @@ class GemmTensorOp {
// warpgroup_fence_operand(acc);
}
template <int wg_wait=0>
static CUTE_DEVICE void body_rs(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) {
template <int wg_wait = 0>
static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
// TODO: Move bar.sync out of body_rs
// asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n * 32));
// asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n *
// 32));
const int tid = threadIdx.x;
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type*>(pB)), SmemLayoutB{});
auto tiled_mma =
make_tiled_mma(GMMA::rs_op_selector<A_type, B_type, C_type, Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
GmmaMajorA, GmmaMajorB>(),
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
auto tiled_mma = make_tiled_mma(
GMMA::rs_op_selector<A_type, B_type, C_type,
Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
GmmaMajorA, GmmaMajorB>(),
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
auto thr_mma = tiled_mma.get_thread_slice(tid);
// Allocate registers for pipelining
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCrA = make_tensor(make_rmem_ptr(reinterpret_cast<A_type*>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCrA =
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
warpgroup_fence_operand(tCrA);
warpgroup_fence_operand(acc);
@@ -146,7 +188,9 @@ class GemmTensorOp {
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
if constexpr (wg_wait >= 0) {
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(acc);
warpgroup_fence_operand(tCrA);
@@ -156,57 +200,63 @@ class GemmTensorOp {
// gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
// warpgroup_commit_batch();
// if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
// warpgroup_fence_operand(acc);
}
};
} // namespace cute
} // namespace cute
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int wg_wait=0,
typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type* pA, B_type* pB, C_type* accum) {
using MMA =
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>;
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body<wg_wait>(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int wg_wait=0,
typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type* pA, B_type* pB, C_type* accum) {
using MMA =
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>;
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body_rs<wg_wait>(pA, pB, accum);
}
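// Passing wg_wait = -1 skips the warpgroup_wait inside body()/body_rs() (see the
// `if constexpr (wg_wait >= 0)` guards above), so a caller can keep issuing work and
// wait later, e.g. via tl::wait_wgmma<N>() below.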
template <int num_mma>
TL_DEVICE void wait_wgmma() {
template <int num_mma> TL_DEVICE void wait_wgmma() {
warpgroup_wait<num_mma>();
}
template <int NumMmaThreads>
TL_DEVICE void warp_scheduler_barrier_sync() {
cutlass::arch::NamedBarrier::sync(
NumMmaThreads,
cutlass::canonical_warp_group_idx() /*id*/);
template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_sync() {
cutlass::arch::NamedBarrier::sync(NumMmaThreads,
cutlass::canonical_warp_group_idx() /*id*/);
}
template <int NumMmaThreads>
TL_DEVICE void warp_scheduler_barrier_arrive() {
template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_arrive() {
static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
if constexpr (NumMmaThreads == 256) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, (1 - cutlass::canonical_warp_group_idx()) /*id*/);
cutlass::arch::NamedBarrier::arrive(
NumMmaThreads, (1 - cutlass::canonical_warp_group_idx()) /*id*/);
} else {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, (cutlass::canonical_warp_group_idx() <= 0 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
cutlass::arch::NamedBarrier::arrive(
NumMmaThreads,
(cutlass::canonical_warp_group_idx() <= 1
? cutlass::canonical_warp_group_idx() + 1
: cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
cutlass::arch::NamedBarrier::arrive(
NumMmaThreads,
(cutlass::canonical_warp_group_idx() <= 0
? cutlass::canonical_warp_group_idx() + 2
: cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
}
}
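// Barrier-id scheme implied by the arithmetic above: warp_scheduler_barrier_sync()
// waits on the named barrier whose id is the warp group's own index, while arrive()
// signals the barrier(s) owned by the other group(s): with two MMA warp groups the
// ids simply swap (1 - idx); with three, each group arrives at (idx + 1) % 3 and
// (idx + 2) % 3.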
template <int NumMmaThreads>
TL_DEVICE void mma_init() {
template <int NumMmaThreads> TL_DEVICE void mma_init() {
static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
if (cutlass::canonical_warp_group_idx() > 0) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0);
@@ -217,4 +267,4 @@ TL_DEVICE void mma_init() {
}
}
}
} // namespace tl
} // namespace tl
@@ -6,97 +6,118 @@
namespace tl {
TL_DEVICE void ptx_ldmatrix_x1(void const* const smem_ptr, void* const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x1(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(value[0])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_ldmatrix_x2(void const* const smem_ptr, void* const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x2(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(value[0]), "=r"(value[1])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_ldmatrix_x4(void const* const smem_ptr, void* const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
: "r"(smem_int_ptr));
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_ldmatrix_x1_trans(void const* const smem_ptr, void* const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(value[0])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_ldmatrix_x2_trans(void const* const smem_ptr, void* const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(value[0]), "=r"(value[1])
: "r"(smem_int_ptr));
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile(
"ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(value[0]), "=r"(value[1])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_ldmatrix_x4_trans(void const* const smem_ptr, void* const local_ptr) {
TL_DEVICE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
: "r"(smem_int_ptr));
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
asm volatile(
"ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_stmatrix_x1(void const* const smem_ptr, const int32_t& value0) {
TL_DEVICE void ptx_stmatrix_x1(void const *const smem_ptr,
const int32_t &value0) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" ::"r"(smem_int_ptr),
asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" ::"r"(
smem_int_ptr),
"r"(value0));
}
TL_DEVICE void ptx_stmatrix_x2(void const* const smem_ptr, const int32_t& value0,
const int32_t& value1) {
TL_DEVICE void ptx_stmatrix_x2(void const *const smem_ptr,
const int32_t &value0, const int32_t &value1) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(smem_int_ptr),
"r"(value0), "r"(value1));
asm volatile(
"stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(
smem_int_ptr),
"r"(value0), "r"(value1));
}
TL_DEVICE void ptx_stmatrix_x4(void const* const smem_ptr, const int32_t& value0,
const int32_t& value1, const int32_t& value2,
const int32_t& value3) {
TL_DEVICE void ptx_stmatrix_x4(void const *const smem_ptr,
const int32_t &value0, const int32_t &value1,
const int32_t &value2, const int32_t &value3) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(smem_int_ptr),
"stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" ::
"r"(smem_int_ptr),
"r"(value0), "r"(value1), "r"(value2), "r"(value3));
}
TL_DEVICE void ptx_stmatrix_x1_trans(void const* const smem_ptr, const int32_t& value0) {
TL_DEVICE void ptx_stmatrix_x1_trans(void const *const smem_ptr,
const int32_t &value0) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" ::"r"(smem_int_ptr),
"r"(value0));
asm volatile(
"stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" ::"r"(
smem_int_ptr),
"r"(value0));
}
TL_DEVICE void ptx_stmatrix_x2_trans(void const* const smem_ptr, const int32_t& value0,
const int32_t& value1) {
TL_DEVICE void ptx_stmatrix_x2_trans(void const *const smem_ptr,
const int32_t &value0,
const int32_t &value1) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(smem_int_ptr),
"stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(
smem_int_ptr),
"r"(value0), "r"(value1));
}
TL_DEVICE void ptx_stmatrix_x4_trans(void const* const smem_ptr, const int32_t& value0,
const int32_t& value1, const int32_t& value2,
const int32_t& value3) {
TL_DEVICE void ptx_stmatrix_x4_trans(void const *const smem_ptr,
const int32_t &value0,
const int32_t &value1,
const int32_t &value2,
const int32_t &value3) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(
smem_int_ptr),
asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, "
"%3, %4};\n" ::"r"(smem_int_ptr),
"r"(value0), "r"(value1), "r"(value2), "r"(value3));
}
} // namespace tl
\ No newline at end of file
} // namespace tl
\ No newline at end of file
@@ -7,34 +7,29 @@
namespace tl {
struct SumOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x + y;
}
};
struct MaxOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return cutlass::fast_max(x, y);
}
};
struct MinOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return cutlass::fast_min(x, y);
}
};
template <class Reducer, int threads, int scale>
struct AllReduce {
static_assert(threads == 1024 or threads == 512 or threads == 256 or threads == 128 or
threads == 64 or threads == 32 or threads == 16 or threads == 8 or threads == 4 or
threads == 2);
template <class Reducer, int threads, int scale> struct AllReduce {
static_assert(threads == 1024 or threads == 512 or threads == 256 or
threads == 128 or threads == 64 or threads == 32 or
threads == 16 or threads == 8 or threads == 4 or threads == 2);
static_assert(threads % scale == 0);
template <typename T>
static TL_DEVICE T run(T x, T* red_buf = nullptr) {
template <typename T> static TL_DEVICE T run(T x, T *red_buf = nullptr) {
constexpr int offset = threads / 2;
if constexpr (offset >= 32) {
__syncthreads();
@@ -54,4 +49,4 @@ struct AllReduce {
}
};
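// Usage sketch (hypothetical; assumes scale == 1 means "reduce across all `threads`
// lanes" and that the intra-warp shuffle path needs no shared buffer, so red_buf may
// stay nullptr for threads <= 32):
//   float warp_max = AllReduce<MaxOp, 32, 1>::run(thread_local_max);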
} // namespace tl
} // namespace tl
@@ -6,8 +6,7 @@
namespace tl {
template <int panel_width>
TL_DEVICE dim3 rasterization2DRow() {
template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.x;
@@ -15,15 +14,17 @@ TL_DEVICE dim3 rasterization2DRow() {
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.x;
const unsigned int col_idx =
(panel_idx & 1) ? gridDim.x - 1 - panel_offset / stride : panel_offset / stride;
panel_idx + 1 < total_panel
? panel_width
: (grid_size - panel_idx * panel_size) / gridDim.x;
const unsigned int col_idx = (panel_idx & 1)
? gridDim.x - 1 - panel_offset / stride
: panel_offset / stride;
const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}
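// The remapping above walks the grid panel by panel: each panel spans panel_width
// values of the returned row index, consecutive linear block ids first step through
// the rows of one column inside a panel before moving to the next column, and the
// column direction alternates between panels (serpentine order), the usual trick for
// keeping neighbouring blocks on nearby tiles and improving L2 reuse.
// rasterization2DColumn below is the column-major analogue.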
template <int panel_width>
TL_DEVICE dim3 rasterization2DColumn() {
template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.y;
@@ -31,11 +32,14 @@ TL_DEVICE dim3 rasterization2DColumn() {
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.y;
const unsigned int row_idx =
(panel_idx & 1) ? gridDim.y - 1 - panel_offset / stride : panel_offset / stride;
panel_idx + 1 < total_panel
? panel_width
: (grid_size - panel_idx * panel_size) / gridDim.y;
const unsigned int row_idx = (panel_idx & 1)
? gridDim.y - 1 - panel_offset / stride
: panel_offset / stride;
const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}
} // namespace tl
} // namespace tl
@@ -2,10 +2,10 @@
// Licensed under the MIT License.
#pragma once
#include <hip/hip_runtime.h>
#include <ck_tile/core.hpp>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <rocwmma/rocwmma.hpp>
#include <ck_tile/core.hpp>
using ck_tile::half_t;
@@ -36,12 +36,16 @@ using ck_tile::half_t;
using float16_t = _Float16;
using float16x2 = __attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t;
using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
using float16x8 = __attribute__((__vector_size__(8 * sizeof(float16_t)))) float16_t;
using float16x16 = __attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t;
using float16x2 =
__attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t;
using float16x4 =
__attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
using float16x8 =
__attribute__((__vector_size__(8 * sizeof(float16_t)))) float16_t;
using float16x16 =
__attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t;
using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
@@ -49,7 +53,7 @@ using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t;
// Pack two half_t values.
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y);
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
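// The first argument ends up in the low 16 bits of the result and the second in the
// high 16 bits, i.e. __pack_half2(x, y) == (bits_of(y) << 16) | bits_of(x).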
@@ -16,12 +16,13 @@ using index_t = u32;
using ck_tile::int32x4_t;
struct __attribute__((packed)) buffer_resource {
const void* ptr;
const void *ptr;
uint32_t range;
uint32_t config;
};
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) {
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr,
uint32_t size = 0xffffffff) {
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
@@ -56,48 +57,56 @@ __device__ void async_gld_sld_fence(index_t cnt) {
__device__ void wave_barrier() { asm volatile("s_barrier" : : : "memory"); }
template <int N = 0>
TL_DEVICE void cp_async_wait() {
template <int N = 0> TL_DEVICE void cp_async_wait() {
async_gld_fence(N);
// or
// async_gld_sld_fence(N);
}
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, int32x4_t rsrc, index_t voffset) {
auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(smem)));
asm volatile(
"s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(voffset), "s"(rsrc)
: "memory");
CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc,
index_t voffset) {
auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(smem)));
asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(voffset), "s"(rsrc)
: "memory");
}
template <int N>
TL_DEVICE void cp_async_gs(void* lds_base_ptr, void* global_base_ptr) {
if constexpr(N == 16) {
*(uint4*)lds_base_ptr = *(uint4*)global_base_ptr;
} else if constexpr(N == 8) {
*(uint2*)lds_base_ptr = *(uint2*)global_base_ptr;
} else if constexpr(N == 4) {
async_buffer_load_dword_v(lds_base_ptr, make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x), threadIdx.x * N /*assume 4 bytes*/);
TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
if constexpr (N == 16) {
*(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr;
} else if constexpr (N == 8) {
*(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr;
} else if constexpr (N == 4) {
async_buffer_load_dword_v(
lds_base_ptr,
make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
threadIdx.x * N /*assume 4 bytes*/);
}
}
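// Note on the N == 4 path above: every lane passes its own global pointer; subtracting
// threadIdx.x int32 elements makes the buffer-resource base wave-uniform (it is
// readfirstlane'd inside make_wave_buffer_resource), and the per-lane voffset of
// threadIdx.x * 4 bytes adds the lane offset back. This assumes the per-lane global
// addresses form consecutive 4-byte words across the wave.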
template <int N>
TL_DEVICE void cp_async_gs_conditional(void* lds_base_ptr, void* global_base_ptr, bool cond) {
if constexpr(N == 16){
*(uint4*)lds_base_ptr = cond? *(uint4*)global_base_ptr: make_uint4(0,0,0,0);
}else if constexpr(N == 8){
*(uint2*)lds_base_ptr = cond? *(uint2*)global_base_ptr: make_uint2(0,0);
}else{
TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
void *global_base_ptr, bool cond) {
if constexpr (N == 16) {
*(uint4 *)lds_base_ptr =
cond ? *(uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0);
} else if constexpr (N == 8) {
*(uint2 *)lds_base_ptr =
cond ? *(uint2 *)global_base_ptr : make_uint2(0, 0);
} else {
if (cond) {
async_buffer_load_dword_v(lds_base_ptr, make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x), threadIdx.x * N /*assume 4 bytes*/);
}else{
*(uint4*)lds_base_ptr = make_uint4(0,0,0,0);
async_buffer_load_dword_v(
lds_base_ptr,
make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
threadIdx.x * N /*assume 4 bytes*/);
} else {
*(uint4 *)lds_base_ptr = make_uint4(0, 0, 0, 0);
}
}
}
} // namespace tl
} // namespace tl
@@ -6,12 +6,12 @@
namespace tl {
// ref to bitblas/tl/mfma_macro_generator.py::kPack
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA, bool TransposeB, int kPack,
typename A_type, typename B_type, typename C_type, typename AccDataType = float>
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA,
bool TransposeB, int kPack, typename A_type, typename B_type,
typename C_type, typename AccDataType = float>
class GemmTensorOp {
public:
public:
static constexpr int micro_size_x = 16;
static constexpr int micro_size_y = 16;
static constexpr int micro_size_k = 16;
@@ -28,7 +28,8 @@ class GemmTensorOp {
static constexpr int warp_rows = M_Tile / (block_row_warps * micro_size_x);
static constexpr int warp_cols = N_Tile / (block_col_warps * micro_size_y);
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen
// part.
static constexpr bool kPadA = true;
static constexpr bool kPadB = true;
static constexpr bool kPadC = true;
@@ -37,12 +38,16 @@ class GemmTensorOp {
static constexpr int warp_size = 64;
TL_DEVICE static constexpr auto reverse_index_map(int thread_id, int local_id) {
return std::make_pair(thread_id % 16, (thread_id / 16) * (4 * kPack) + local_id);
TL_DEVICE static constexpr auto reverse_index_map(int thread_id,
int local_id) {
return std::make_pair(thread_id % 16,
(thread_id / 16) * (4 * kPack) + local_id);
}
TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id, int local_id) {
return std::make_pair((thread_id / 16) * (4 * kPack) + local_id, thread_id % 16);
TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id,
int local_id) {
return std::make_pair((thread_id / 16) * (4 * kPack) + local_id,
thread_id % 16);
}
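// Worked example with kPack = 1: for thread_id = 17 and local_id = 2,
// reverse_index_map yields (17 % 16, (17 / 16) * 4 + 2) = (1, 6), and the
// transposed variant yields (6, 1); i.e. lane 17 owns element (1, 6) of its
// 16x16 fragment.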
/*
......@@ -62,7 +67,8 @@ class GemmTensorOp {
const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
const int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
const int maxPhase =
std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
const int phase = (row / perPhase) % maxPhase;
const int colOffSwizzled = (((col / vecSize) ^ phase) * vecSize);
......@@ -73,16 +79,19 @@ class GemmTensorOp {
}
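// Worked example of the XOR swizzle above, using illustrative values
// perPhase = 1, maxPhase = 8 and vecSize = 8: an access at (row = 3, col = 16)
// gets phase = (3 / 1) % 8 = 3 and a swizzled column of ((16 / 8) ^ 3) * 8 = 8,
// spreading rows that would otherwise map to the same shared-memory banks.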
template <int continuous = 32, int element_size = 2>
TL_DEVICE static constexpr auto make_layout_padded(const int row, const int col) {
TL_DEVICE static constexpr auto make_layout_padded(const int row,
const int col) {
return std::make_pair(row, col);
}
template <int continuous = 32, int element_size = 2>
TL_DEVICE static constexpr auto make_swizzle_layout(const int row, const int col) {
TL_DEVICE static constexpr auto make_swizzle_layout(const int row,
const int col) {
constexpr auto vector_size = BANK_SIZE_BYTES / (element_size * 8);
if (continuous % (vector_size * 4) == 0) {
auto [n_row, n_col] = make_mfma_swizzle_layout<continuous, element_size>(row, col);
auto [n_row, n_col] =
make_mfma_swizzle_layout<continuous, element_size>(row, col);
return n_row * continuous + n_col;
} else {
auto [n_row, n_col] = make_layout_padded(row, col);
......@@ -93,7 +102,8 @@ class GemmTensorOp {
}
}
static TL_DEVICE void body(A_type* A_shared, B_type* B_shared, C_type* C_local) {
static TL_DEVICE void body(A_type *A_shared, B_type *B_shared,
C_type *C_local) {
auto tid = threadIdx.x;
auto warp_id = tid / warp_size;
auto warp_n = warp_id / block_row_warps;
......@@ -122,7 +132,8 @@ class GemmTensorOp {
for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
auto [row, col] = reverse_index_map(lane_id, local_id);
A_local[i * kPack * local_size_a + local_id] =
A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(l + row, r + col)];
A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
l + row, r + col)];
}
}
......@@ -133,7 +144,8 @@ class GemmTensorOp {
for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
auto [row, col] = reverse_index_map(lane_id, local_id);
B_local[j * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(l + row, r + col)];
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)];
}
}
......@@ -141,17 +153,19 @@ class GemmTensorOp {
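// Each __builtin_amdgcn_mfma_f32_16x16x16f16 call below performs a 16x16x16
// f16 matrix multiply-accumulate across the 64-lane wavefront: every lane
// supplies a float16x4 fragment of A and of B and accumulates into its
// float32x4 slice of C. Note the B fragment is passed as the first operand
// and the A fragment as the second.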
for (int kp = 0; kp < kPack; kp++) {
for (int i = 0; i < warp_rows; ++i) {
for (int j = 0; j < warp_cols; ++j) {
*(((float32x4*)C_local) + ((i * warp_cols) + j)) = __builtin_amdgcn_mfma_f32_16x16x16f16(
*(((float16x4*)B_local) + j * kPack + kp),
*(((float16x4*)A_local) + i * kPack + kp),
*(((float32x4*)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
*(((float32x4 *)C_local) + ((i * warp_cols) + j)) =
__builtin_amdgcn_mfma_f32_16x16x16f16(
*(((float16x4 *)B_local) + j * kPack + kp),
*(((float16x4 *)A_local) + i * kPack + kp),
*(((float32x4 *)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
}
}
}
}
}
static TL_DEVICE void body_rs(A_type* A_local, B_type* B_shared, C_type* C_local) {
static TL_DEVICE void body_rs(A_type *A_local, B_type *B_shared,
C_type *C_local) {
auto tid = threadIdx.x;
auto warp_id = tid / warp_size;
auto warp_n = warp_id / block_row_warps;
......@@ -179,7 +193,8 @@ class GemmTensorOp {
for (int local_id = 0; local_id < kPack * local_size_b; local_id++) {
auto [row, col] = reverse_index_map(lane_id, local_id);
B_local[j * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(l + row, r + col)];
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)];
}
}
......@@ -187,9 +202,12 @@ class GemmTensorOp {
for (int kp = 0; kp < kPack; kp++) {
for (int i = 0; i < warp_rows; ++i) {
for (int j = 0; j < warp_cols; ++j) {
*(((float32x4*)C_local) + ((i * warp_cols) + j)) = __builtin_amdgcn_mfma_f32_16x16x16f16(
*(((float16x4*)B_local) + j * kPack + kp), *(((float16x4*)A_local) + ki * warp_rows * kPack + i * kPack + kp),
*(((float32x4*)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
*(((float32x4 *)C_local) + ((i * warp_cols) + j)) =
__builtin_amdgcn_mfma_f32_16x16x16f16(
*(((float16x4 *)B_local) + j * kPack + kp),
*(((float16x4 *)A_local) + ki * warp_rows * kPack +
i * kPack + kp),
*(((float32x4 *)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
}
}
}
......@@ -197,24 +215,26 @@ class GemmTensorOp {
}
};
} // namespace tl
} // namespace tl
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int kPack,
typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type* pA, B_type* pB, C_type* accum) {
using Compute =
GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, kPack, A_type, B_type, C_type>;
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, int kPack, typename A_type, typename B_type,
typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, kPack, A_type, B_type, C_type>;
Compute::body(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int kPack,
typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type* pA, B_type* pB, C_type* accum) {
using Compute =
GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, kPack, A_type, B_type, C_type>;
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, int kPack, typename A_type, typename B_type,
typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, kPack, A_type, B_type, C_type>;
Compute::body_rs(pA, pB, accum);
}
} // namespace tl
} // namespace tl
......@@ -7,35 +7,30 @@
namespace tl {
struct SumOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x + y;
}
};
struct MaxOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return ck_tile::max(x, y);
}
};
struct MinOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return ck_tile::min(x, y);
}
};
template <class Reducer, int threads, int scale>
struct AllReduce {
static_assert(threads == 1024 || threads == 512 || threads == 256 || threads == 128 ||
threads == 64 || threads == 32 || threads == 16 || threads == 8 || threads == 4 ||
threads == 2);
template <class Reducer, int threads, int scale> struct AllReduce {
static_assert(threads == 1024 || threads == 512 || threads == 256 ||
threads == 128 || threads == 64 || threads == 32 ||
threads == 16 || threads == 8 || threads == 4 || threads == 2);
static_assert(threads % scale == 0);
template <typename T>
static __device__ T run(T x, T* red_buf = nullptr) {
template <typename T> static __device__ T run(T x, T *red_buf = nullptr) {
constexpr int offset = threads / 2;
constexpr int warpSize = 64;
......@@ -55,4 +50,4 @@ struct AllReduce {
}
};
} // namespace tl
} // namespace tl
......@@ -6,8 +6,7 @@
namespace tl {
template <int panel_width>
TL_DEVICE dim3 rasterization2DRow() {
template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
......@@ -16,15 +15,17 @@ TL_DEVICE dim3 rasterization2DRow() {
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.x;
const unsigned int col_idx =
(panel_idx & 1) ? gridDim.x - 1 - panel_offset / stride : panel_offset / stride;
panel_idx + 1 < total_panel
? panel_width
: (grid_size - panel_idx * panel_size) / gridDim.x;
const unsigned int col_idx = (panel_idx & 1)
? gridDim.x - 1 - panel_offset / stride
: panel_offset / stride;
const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}
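// Standalone host-side sketch (illustrative only, not part of this header)
// that mirrors the arithmetic of rasterization2DRow for an assumed 4x4 grid
// with panel_width = 2 and prints the remapped (col, row) of every block, so
// the serpentine traversal order is easy to inspect. It assumes the elided
// definitions panel_size = panel_width * gridDim.x and
// panel_offset = block_idx % panel_size.
#include <cstdio>
int main() {
  constexpr int panel_width = 2, grid_x = 4, grid_y = 4;
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  for (int block_idx = 0; block_idx < grid_x * grid_y; ++block_idx) {
    const int grid_size = grid_x * grid_y;
    const int panel_size = panel_width * grid_x;
    const int panel_offset = block_idx % panel_size;
    const int panel_idx = block_idx / panel_size;
    const int total_panel = ceil_div(grid_size, panel_size);
    const int stride = panel_idx + 1 < total_panel
                           ? panel_width
                           : (grid_size - panel_idx * panel_size) / grid_x;
    const int col = (panel_idx & 1) ? grid_x - 1 - panel_offset / stride
                                    : panel_offset / stride;
    const int row = panel_offset % stride + panel_idx * panel_width;
    std::printf("block %2d -> (col %d, row %d)\n", block_idx, col, row);
  }
  return 0;
}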
template <int panel_width>
TL_DEVICE dim3 rasterization2DColumn() {
template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
......@@ -33,11 +34,14 @@ TL_DEVICE dim3 rasterization2DColumn() {
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.y;
const unsigned int row_idx =
(panel_idx & 1) ? gridDim.y - 1 - panel_offset / stride : panel_offset / stride;
panel_idx + 1 < total_panel
? panel_width
: (grid_size - panel_idx * panel_size) / gridDim.y;
const unsigned int row_idx = (panel_idx & 1)
? gridDim.y - 1 - panel_offset / stride
: panel_offset / stride;
const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}
} // namespace tl
} // namespace tl
......@@ -31,15 +31,17 @@ namespace tvm {
namespace tir {
class ClusterPlanner {
public:
static PrimFunc Substitute(PrimFunc& f) {
public:
static PrimFunc Substitute(PrimFunc &f) {
// Step 1: Collect the read region of the function
Map<Var, Buffer> buffer_data_to_buffer_;
for (const auto& [_, buffer] : f->buffer_map) {
for (const auto &[_, buffer] : f->buffer_map) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ f->body);
Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ f->body);
Array<Array<BufferRegion>> access =
GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto reads = access[0];
BlockIdxVisitor blockIdx_visitor;
......@@ -47,20 +49,22 @@ class ClusterPlanner {
auto dom_map = blockIdx_visitor.dom_map_;
// Step 2: Collect mem reuse count for clustering on each dimension.
std::unordered_map<const IterVarNode*, size_t> mem_reuse_count;
for (auto iv : dom_map) mem_reuse_count[iv] = 0;
std::unordered_map<const IterVarNode *, size_t> mem_reuse_count;
for (auto iv : dom_map)
mem_reuse_count[iv] = 0;
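// A read region whose index expressions never mention a given blockIdx var is
// fetched identically by every block along that dimension, so its size is
// added to that dimension's reuse score.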
for (const auto& buffer_region : reads) {
for (const auto &buffer_region : reads) {
PrimExpr size = buffer_region->buffer->dtype.bits();
RegionVisitor visitor;
for (const auto& range : buffer_region->region) {
for (const auto &range : buffer_region->region) {
size = size * range->extent;
visitor(range->min);
}
size = arith::Analyzer().Simplify(size);
if (auto imm = size.as<IntImmNode>()) {
for (auto iv : dom_map) {
if (visitor.seen_.count(iv->var.get()) == 0) mem_reuse_count[iv] += imm->value;
if (visitor.seen_.count(iv->var.get()) == 0)
mem_reuse_count[iv] += imm->value;
}
}
}
......@@ -70,7 +74,8 @@ class ClusterPlanner {
String cluster_tag;
for (auto iv : dom_map) {
if (auto extent = iv->dom->extent.as<IntImmNode>()) {
if (extent->value % cluster_size_ == 0 && mem_reuse_count[iv] > mem_reuse_max) {
if (extent->value % cluster_size_ == 0 &&
mem_reuse_count[iv] > mem_reuse_max) {
cluster_tag = iv->thread_tag;
mem_reuse_max = mem_reuse_count[iv];
}
......@@ -78,27 +83,28 @@ class ClusterPlanner {
}
if (mem_reuse_max > 0) {
cluster_tag = "clusterIdx" + String(cluster_tag.c_str() + strlen("blockIdx"));
cluster_tag =
"clusterIdx" + String(cluster_tag.c_str() + strlen("blockIdx"));
return WithAttr(f, cluster_tag, Integer(cluster_size_));
} else {
return f;
}
}
private:
private:
ClusterPlanner() = default;
class RegionVisitor : public ExprVisitor {
public:
public:
RegionVisitor(){};
void VisitExpr_(const VarNode* var) { seen_.insert(var); }
std::unordered_set<const VarNode*> seen_;
void VisitExpr_(const VarNode *var) { seen_.insert(var); }
std::unordered_set<const VarNode *> seen_;
};
class BlockIdxVisitor : public StmtVisitor {
public:
public:
BlockIdxVisitor(){};
void VisitStmt_(const AttrStmtNode* attr) final {
void VisitStmt_(const AttrStmtNode *attr) final {
if (attr->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(attr->node);
String tag = iv->thread_tag;
......@@ -108,7 +114,7 @@ class ClusterPlanner {
StmtVisitor::VisitStmt_(attr);
}
/*! \brief The set of blockIdx iter vars (each records its thread extent). */
std::unordered_set<const IterVarNode*> dom_map_;
std::unordered_set<const IterVarNode *> dom_map_;
};
/*! \brief Currently the possible cluster size is fixed at 2 */
......@@ -126,8 +132,9 @@ tvm::transform::Pass ClusterPlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ClusterPlanning").set_body_typed(ClusterPlanning);
} // namespace transform
TVM_REGISTER_GLOBAL("tl.transform.ClusterPlanning")
.set_body_typed(ClusterPlanning);
} // namespace transform
} // namespace tir
} // namespace tvm
} // namespace tir
} // namespace tvm
......@@ -32,10 +32,10 @@
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../../op/parallel.h"
#include "../loop_partition.h"
#include "../loop_vectorize.h"
#include "arith/ir_mutator_with_analyzer.h"
namespace tvm {
namespace tl {
......@@ -44,15 +44,15 @@ using namespace tir;
using arith::IRMutatorWithAnalyzer;
class FragmentAccessDetector : public StmtExprVisitor {
public:
public:
FragmentAccessDetector() = default;
void Collect(Stmt stmt) { VisitStmt(stmt); }
bool HasFragmentAccess() { return has_fragment_access_; }
private:
void VisitExpr_(const BufferLoadNode* op) final {
private:
void VisitExpr_(const BufferLoadNode *op) final {
// Check if the buffer is in the local.fragment scope
if (IsFragmentBuffer(op->buffer)) {
has_fragment_access_ = true;
......@@ -60,7 +60,7 @@ class FragmentAccessDetector : public StmtExprVisitor {
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) final {
void VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is in the local.fragment scope
if (IsFragmentBuffer(op->buffer)) {
has_fragment_access_ = true;
......@@ -69,8 +69,9 @@ class FragmentAccessDetector : public StmtExprVisitor {
}
// Helper function to determine if a buffer is local.fragment
bool IsFragmentBuffer(const Buffer& buffer) {
// The storage scope is often encoded in the buffer->data var name or associated attributes.
bool IsFragmentBuffer(const Buffer &buffer) {
// The storage scope is often encoded in the buffer->data var name or
// associated attributes.
String scope = buffer.scope();
return scope == "local.fragment";
}
......@@ -87,23 +88,25 @@ class FragmentAccessDetector : public StmtExprVisitor {
* Once fused, a single loop variable will replace the chain, and the
* original loop variables will be derived by division and modulo operations.
*
* This can be helpful for inferring layout for the fragment in a subsequent pass.
* This can be helpful for inferring layout for the fragment in a subsequent
* pass.
*/
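// Illustrative example (TIR-like pseudocode, not taken from a real trace):
// two nested parallel loops
//   for (i in [0, 4))   // parallel
//     for (j in [0, 8)) // parallel
//       body(i, j)
// are replaced by one parallel loop over a fused variable of extent 32, with
// the original indices recovered as i = fused / 8 and j = fused % 8 through
// the stride computation below (strides = {8, 1}).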
class ParallelLoopFuser : public IRMutatorWithAnalyzer {
public:
public:
static Stmt Fuse(Stmt stmt) {
arith::Analyzer analyzer;
ParallelLoopFuser substituter(&analyzer);
return substituter.VisitStmt(stmt);
}
private:
ParallelLoopFuser(arith::Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {};
private:
ParallelLoopFuser(arith::Analyzer *analyzer)
: IRMutatorWithAnalyzer(analyzer){};
Stmt VisitStmt_(const ForNode* op) final {
Stmt VisitStmt_(const ForNode *op) final {
// Gather consecutive parallel loops
std::vector<const ForNode*> loop_chain;
const ForNode* current = op;
std::vector<const ForNode *> loop_chain;
const ForNode *current = op;
// check if has fragment access
FragmentAccessDetector detector;
detector.Collect(op->body);
......@@ -113,11 +116,13 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
}
while (true) {
if (current->kind != ForKind::kParallel) break;
if (!is_zero(current->min)) break;
if (current->kind != ForKind::kParallel)
break;
if (!is_zero(current->min))
break;
loop_chain.push_back(current);
const ForNode* inner_for = current->body.as<ForNode>();
const ForNode *inner_for = current->body.as<ForNode>();
if (!inner_for) {
break;
}
......@@ -147,7 +152,7 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
Var fused_var(fused_name, DataType::Int(32));
// The body of the last loop in the chain:
const ForNode* innermost_loop = loop_chain.back();
const ForNode *innermost_loop = loop_chain.back();
Stmt body = innermost_loop->body;
// We need to substitute all loop variables in the chain.
......@@ -175,7 +180,8 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
extents.push_back(l->extent);
}
std::vector<PrimExpr> strides(loop_chain.size(), make_const(DataType::Int(32), 1));
std::vector<PrimExpr> strides(loop_chain.size(),
make_const(DataType::Int(32), 1));
for (int i = static_cast<int>(loop_chain.size()) - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * extents[i + 1];
}
......@@ -189,8 +195,9 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
Map<Var, PrimExpr> var_map;
for (size_t i = 0; i < loop_chain.size(); i++) {
const ForNode* loop = loop_chain[i];
var_map.Set(loop->loop_var, analyzer_->Simplify(create_index_expr(static_cast<int>(i))));
const ForNode *loop = loop_chain[i];
var_map.Set(loop->loop_var,
analyzer_->Simplify(create_index_expr(static_cast<int>(i))));
}
// Perform the substitution
......@@ -203,5 +210,5 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
}
};
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -32,10 +32,10 @@
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../../op/parallel.h"
#include "../loop_partition.h"
#include "../loop_vectorize.h"
#include "arith/ir_mutator_with_analyzer.h"
namespace tvm {
namespace tl {
......@@ -46,7 +46,8 @@ using namespace tir;
// Use the same code as tir.transform.vectorize_loop
inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
if (is_scalable) {
return Mul(Call(DataType::Int(32), builtin::vscale(), {}), lanes_or_vscale_factor);
return Mul(Call(DataType::Int(32), builtin::vscale(), {}),
lanes_or_vscale_factor);
} else {
return lanes_or_vscale_factor;
}
......@@ -58,7 +59,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
e.dtype().is_scalable_vector() == is_scalable)
return e;
if (const BroadcastNode* op = e.as<BroadcastNode>()) {
if (const BroadcastNode *op = e.as<BroadcastNode>()) {
ICHECK(op->dtype.is_scalable_vector() == is_scalable)
<< "Can't broadcast between scalable and fixed length vectors.";
int e_lanes = op->dtype.get_lanes_or_vscale_factor();
......@@ -68,40 +69,39 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
}
}
ICHECK(e.dtype().is_scalar()) << "Cannot broadcast lanes="
<< e.dtype().get_lanes_or_vscale_factor()
<< " is_scalable=" << e.dtype().is_scalable_vector() << " to "
<< lanes;
ICHECK(e.dtype().is_scalar())
<< "Cannot broadcast lanes=" << e.dtype().get_lanes_or_vscale_factor()
<< " is_scalable=" << e.dtype().is_scalable_vector() << " to " << lanes;
return Broadcast(e, CreateNewLanes(is_scalable, lanes));
}
// Rewrite vectorized allocation access
// This is necessary so that each vector component gets its own workspace.
// Originates from Halide's loop vectorizer
// This is necessary so that each vector component gets its own
// workspace. Originates from Halide's loop vectorizer
//
// s[i] = s[i * lanes + var]
//
// The same principle applies when using one thread to simulate multiple contexts.
// The same principle applies when using one thread to simulate multiple
// contexts.
//
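// Example of the rewrite performed here: a buffer allocated as float buf[8]
// inside a loop vectorized over `var` with var_lanes = 4 is re-shaped to
// float buf[8 * 4], and an access buf[i] becomes buf[i * 4 + var], so each
// vector lane works on a private copy of the workspace.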
class VecAllocAccess : public StmtExprMutator {
public:
VecAllocAccess(const VarNode* buf, Var var, PrimExpr var_lanes)
public:
VecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return UpdateBufferAccess(load);
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
return UpdateBufferAccess(store);
}
private:
template <typename Node>
Node UpdateBufferAccess(Node node) {
private:
template <typename Node> Node UpdateBufferAccess(Node node) {
// Only update the buffer that's being replaced.
if (node->buffer->data.get() != buf_) {
return node;
......@@ -117,7 +117,8 @@ class VecAllocAccess : public StmtExprMutator {
// var_lanes_. Typically, this will be a 1-d index into a flat
// memory space.
Array<PrimExpr> shape = node->buffer->shape;
shape.Set(shape.size() - 1, analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));
shape.Set(shape.size() - 1,
analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));
// TODO(Lunderberg): Move this pass to be prior to
// StorageFlatten/FlattenBuffer, implement by appending a
......@@ -146,8 +147,9 @@ class VecAllocAccess : public StmtExprMutator {
// Extend the last index by the number of lanes in the vectorized
// variable.
Array<PrimExpr> indices = node->indices;
indices.Set(indices.size() - 1,
analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_));
indices.Set(
indices.size() - 1,
analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_));
auto writer = node.CopyOnWrite();
writer->buffer = buf;
......@@ -156,9 +158,9 @@ class VecAllocAccess : public StmtExprMutator {
}
// buffer var
const VarNode* buf_;
const VarNode *buf_;
// Updated buffer objects.
std::unordered_map<const BufferNode*, Buffer> buffer_map_;
std::unordered_map<const BufferNode *, Buffer> buffer_map_;
// variable to be replaced
Var var_;
// the lanes.
......@@ -170,8 +172,9 @@ class VecAllocAccess : public StmtExprMutator {
// We use ExprFunctor directly instead of StmtExprMutator
// This is because the transformation can change the dtype of the Expr
// The existing ExprMutator transformation rules may not be well defined.
class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
public:
class Vectorizer : public StmtMutator,
public ExprFunctor<PrimExpr(const PrimExpr &)> {
public:
using ExprFunctor::VisitExpr;
using StmtMutator::operator();
......@@ -179,7 +182,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
}
Stmt VisitStmt(const Stmt& stmt) final {
Stmt VisitStmt(const Stmt &stmt) final {
ICHECK(!need_scalarize_);
Stmt ret = StmtMutator::VisitStmt(stmt);
if (need_scalarize_) {
......@@ -190,17 +193,19 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
PrimExpr VisitExpr(const PrimExpr& e) final { return ExprFunctor::VisitExpr(e); }
PrimExpr VisitExpr(const PrimExpr &e) final {
return ExprFunctor::VisitExpr(e);
}
PrimExpr VisitExpr_(const AddNode* op) final {
PrimExpr VisitExpr_(const AddNode *op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
}
PrimExpr VisitExpr_(const SubNode* op) final {
PrimExpr VisitExpr_(const SubNode *op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; });
}
PrimExpr VisitExpr_(const MulNode* op) final {
PrimExpr VisitExpr_(const MulNode *op) final {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
......@@ -211,11 +216,12 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
if (is_vec_a && is_vec_b) {
// Let's not multiply scalable and fixed length vectors
ICHECK(a.dtype().is_scalable_vector() == b.dtype().is_scalable_vector())
<< "Fixed length and scalable vectors can't be mixed in multiplication.";
<< "Fixed length and scalable vectors can't be mixed in "
"multiplication.";
}
if (is_vec_a || is_vec_b) {
const RampNode* b_ramp = b.as<RampNode>();
const RampNode* a_ramp = a.as<RampNode>();
const RampNode *b_ramp = b.as<RampNode>();
const RampNode *a_ramp = a.as<RampNode>();
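// E.g. Ramp(0, 1, 4) * 2, i.e. [0, 1, 2, 3] * 2, folds to Ramp(0, 2, 4)
// = [0, 2, 4, 6] instead of a broadcast multiply.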
if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) {
PrimExpr lanes = a_ramp->lanes;
return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
......@@ -227,28 +233,34 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
int max_lanes = std::max(a_lanes, b_lanes);
bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return Mul(BroadcastTo(a, max_lanes, is_scalable), BroadcastTo(b, max_lanes, is_scalable));
bool is_scalable =
a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return Mul(BroadcastTo(a, max_lanes, is_scalable),
BroadcastTo(b, max_lanes, is_scalable));
}
}
return BinaryVec<Mul>(op);
}
PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec<Div>(op); }
PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec<Mod>(op); }
PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec<FloorDiv>(op); }
PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec<FloorMod>(op); }
PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec<Min>(op); }
PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec<Max>(op); }
PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec<EQ>(op); }
PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec<NE>(op); }
PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec<LT>(op); }
PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec<LE>(op); }
PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec<GT>(op); }
PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec<GE>(op); }
PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec<And>(op); }
PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec<Or>(op); }
PrimExpr VisitExpr_(const NotNode* op) final {
PrimExpr VisitExpr_(const DivNode *op) final { return BinaryVec<Div>(op); }
PrimExpr VisitExpr_(const ModNode *op) final { return BinaryVec<Mod>(op); }
PrimExpr VisitExpr_(const FloorDivNode *op) final {
return BinaryVec<FloorDiv>(op);
}
PrimExpr VisitExpr_(const FloorModNode *op) final {
return BinaryVec<FloorMod>(op);
}
PrimExpr VisitExpr_(const MinNode *op) final { return BinaryVec<Min>(op); }
PrimExpr VisitExpr_(const MaxNode *op) final { return BinaryVec<Max>(op); }
PrimExpr VisitExpr_(const EQNode *op) final { return BinaryVec<EQ>(op); }
PrimExpr VisitExpr_(const NENode *op) final { return BinaryVec<NE>(op); }
PrimExpr VisitExpr_(const LTNode *op) final { return BinaryVec<LT>(op); }
PrimExpr VisitExpr_(const LENode *op) final { return BinaryVec<LE>(op); }
PrimExpr VisitExpr_(const GTNode *op) final { return BinaryVec<GT>(op); }
PrimExpr VisitExpr_(const GENode *op) final { return BinaryVec<GE>(op); }
PrimExpr VisitExpr_(const AndNode *op) final { return BinaryVec<And>(op); }
PrimExpr VisitExpr_(const OrNode *op) final { return BinaryVec<Or>(op); }
PrimExpr VisitExpr_(const NotNode *op) final {
PrimExpr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) {
return GetRef<PrimExpr>(op);
......@@ -257,7 +269,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
PrimExpr VisitExpr_(const RampNode* op) final {
PrimExpr VisitExpr_(const RampNode *op) final {
PrimExpr base = this->VisitExpr(op->base);
PrimExpr stride = this->VisitExpr(op->stride);
ICHECK(!base.dtype().is_scalable_vector())
......@@ -267,11 +279,13 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) {
ICHECK(op->lanes->IsInstance<IntImmNode>())
<< "Vectorizing over existing scalable vectors is not supported.";
const RampNode* base_ramp = base.as<RampNode>();
const RampNode *base_ramp = base.as<RampNode>();
int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
int base_ramp_lanes = static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
int base_ramp_lanes =
static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
if (analyzer_.CanProve(base_ramp->stride ==
stride * make_const(stride.dtype(), base_ramp_lanes))) {
stride *
make_const(stride.dtype(), base_ramp_lanes))) {
return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes);
}
}
......@@ -280,13 +294,13 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
stride = BroadcastTo(stride, lanes, false);
Array<PrimExpr> elems;
for (int i = 0; i < lanes; ++i) {
elems.push_back(
Ramp(Shuffle::ExtractElement(base, i), Shuffle::ExtractElement(stride, i), op->lanes));
elems.push_back(Ramp(Shuffle::ExtractElement(base, i),
Shuffle::ExtractElement(stride, i), op->lanes));
}
return Shuffle::Concat(elems);
}
PrimExpr VisitExpr_(const BroadcastNode* op) final {
PrimExpr VisitExpr_(const BroadcastNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
if (value.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
......@@ -299,45 +313,56 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
PrimExpr VisitExpr_(const SelectNode* op) final {
PrimExpr VisitExpr_(const SelectNode *op) final {
PrimExpr cond = this->VisitExpr(op->condition);
PrimExpr t = this->VisitExpr(op->true_value);
PrimExpr f = this->VisitExpr(op->false_value);
if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) {
if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return GetRef<PrimExpr>(op);
} else {
int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes);
bool is_scalable = cond.dtype().is_scalable_vector() || t.dtype().is_scalable_vector() ||
bool is_scalable = cond.dtype().is_scalable_vector() ||
t.dtype().is_scalable_vector() ||
f.dtype().is_scalable_vector();
return Select(BroadcastTo(cond, lanes, is_scalable), BroadcastTo(t, lanes, is_scalable),
return Select(BroadcastTo(cond, lanes, is_scalable),
BroadcastTo(t, lanes, is_scalable),
BroadcastTo(f, lanes, is_scalable));
}
}
PrimExpr VisitExpr_(const CastNode* op) final {
PrimExpr VisitExpr_(const CastNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
if (value.dtype().is_scalable_vector()) {
return Cast(op->dtype.with_scalable_vscale_factor(value.dtype().vscale_factor()), value);
return Cast(op->dtype.with_scalable_vscale_factor(
value.dtype().vscale_factor()),
value);
} else {
return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
}
}
}
PrimExpr VisitExpr_(const FloatImmNode* op) final { return GetRef<PrimExpr>(op); }
PrimExpr VisitExpr_(const FloatImmNode *op) final {
return GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef<PrimExpr>(op); }
PrimExpr VisitExpr_(const IntImmNode *op) final {
return GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef<PrimExpr>(op); }
PrimExpr VisitExpr_(const StringImmNode *op) final {
return GetRef<PrimExpr>(op);
}
// Variable
PrimExpr VisitExpr_(const VarNode* op) final {
PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op);
if (var.same_as(var_)) {
......@@ -351,7 +376,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
// IfThenElse expr
PrimExpr MutateIfThenElseExpr_(const CallNode* op) {
PrimExpr MutateIfThenElseExpr_(const CallNode *op) {
PrimExpr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
......@@ -359,24 +384,27 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
PrimExpr t = this->VisitExpr(op->args[1]);
PrimExpr f = this->VisitExpr(op->args[2]);
if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) {
if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
f.same_as(op->args[2])) {
return GetRef<PrimExpr>(op);
} else {
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(t_lanes, f_lanes);
bool is_scalable = t.dtype().is_scalable_vector() || f.dtype().is_scalable_vector();
bool is_scalable =
t.dtype().is_scalable_vector() || f.dtype().is_scalable_vector();
t = BroadcastTo(t, lanes, is_scalable);
f = BroadcastTo(f, lanes, is_scalable);
if (is_scalable) {
return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {cond, t, f});
return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{cond, t, f});
} else {
return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
}
}
}
// Reinterpret expr
PrimExpr MutateReinterpretExpr_(const CallNode* op) {
PrimExpr MutateReinterpretExpr_(const CallNode *op) {
ICHECK(op->op.same_as(builtin::reinterpret()));
PrimExpr value = this->VisitExpr(op->args[0]);
if (value.same_as(op->args[0])) {
......@@ -384,14 +412,15 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
} else {
int lanes = value.dtype().get_lanes_or_vscale_factor();
if (value.dtype().is_scalable_vector()) {
return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value});
return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{value});
} else {
return Call(op->dtype.with_lanes(lanes), op->op, {value});
}
}
}
// Call
PrimExpr VisitExpr_(const CallNode* op) final {
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::if_then_else())) {
return MutateIfThenElseExpr_(op);
} else if (op->op.same_as(builtin::texture2d_load())) {
......@@ -406,13 +435,15 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
// Vectorize the value to store
Array<PrimExpr> value{op->args.back()};
Array<PrimExpr> mutated_value = MutateArray(value, &lane);
Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]};
Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2],
mutated_value[0]};
return Call(op->dtype.with_lanes(lane), op->op, new_args);
} else if (op->op.same_as(builtin::reinterpret())) {
return MutateReinterpretExpr_(op);
}
auto optional_op = op->op.as<Op>();
bool vectorizable = optional_op && op_vectorizable_.get(optional_op.value(), false) &&
bool vectorizable = optional_op &&
op_vectorizable_.get(optional_op.value(), false) &&
!op->dtype.is_scalable_vector();
if (!vectorizable) {
......@@ -443,10 +474,12 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
// BufferLoad
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = GetRef<BufferLoad>(op);
auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index);
};
Array<PrimExpr> indices = op->indices.Map(fmutate);
if (!indices.same_as(op->indices)) {
......@@ -457,7 +490,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
return std::move(load);
}
// Let
PrimExpr VisitExpr_(const LetNode* op) final {
PrimExpr VisitExpr_(const LetNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
// Weaker SSA condition
// A single var can be bound in multiple lets
......@@ -486,24 +519,28 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
// BufferStore
Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = GetRef<BufferStore>(op);
auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index);
};
Array<PrimExpr> indices = op->indices.Map(fmutate);
PrimExpr value = this->VisitExpr(op->value);
if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
ICHECK(!op->buffer->dtype.is_scalable_vector())
<< "Vectorizing over scalable buffer elements is not supported in vectorizer.";
<< "Vectorizing over scalable buffer elements is not supported in "
"vectorizer.";
// How many lanes of indexing are present in the index and
// buffer element type, excluding the last index.
int other_index_lanes = op->buffer->dtype.lanes();
for (size_t i = 0; i < indices.size() - 1; i++) {
other_index_lanes *= indices[i].dtype().lanes();
// Only allow the last index to be scalable
ICHECK(!indices[i].dtype().is_scalable_vector()) << "Only the last index can be scalable.";
ICHECK(!indices[i].dtype().is_scalable_vector())
<< "Only the last index can be scalable.";
}
// The total number of lanes of indexing, including the last index.
......@@ -519,14 +556,16 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
int total_lanes = std::max(index_lanes, value_dtype_lanes);
ICHECK_EQ(total_lanes % other_index_lanes, 0)
<< "When storing to buffer " << op->buffer->name << ", cannot produce " << total_lanes
<< "When storing to buffer " << op->buffer->name
<< ", cannot produce " << total_lanes
<< " lanes of storage location by changing the last index.";
int last_index_lanes = total_lanes / other_index_lanes;
// Broadcast the last index such that the total number of index
// lanes matches the desired number.
indices.Set(indices.size() - 1, BroadcastTo(indices[indices.size() - 1], last_index_lanes,
is_last_index_scalable));
indices.Set(indices.size() - 1,
BroadcastTo(indices[indices.size() - 1], last_index_lanes,
is_last_index_scalable));
auto writer = store.CopyOnWrite();
writer->indices = indices;
......@@ -536,7 +575,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
return std::move(store);
}
// For
Stmt VisitStmt_(const ForNode* op) final {
Stmt VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kVectorized) {
LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
}
......@@ -550,12 +589,12 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding,
op->annotations);
return For(op->loop_var, op->min, extent, op->kind, body,
op->thread_binding, op->annotations);
}
}
// IfThenElse
Stmt VisitStmt_(const IfThenElseNode* op) final {
Stmt VisitStmt_(const IfThenElseNode *op) final {
ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
......@@ -574,13 +613,14 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
// While
Stmt VisitStmt_(const WhileNode* op) final {
Stmt VisitStmt_(const WhileNode *op) final {
LOG(FATAL) << "A while loop inside a vectorized loop not supported.";
}
// LetStmt
Stmt VisitStmt_(const LetStmtNode* op) final {
Stmt VisitStmt_(const LetStmtNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is bound twice";
ICHECK(!let_binding_.count(op->var))
<< "SSA violation, a single var is bound twice";
let_binding_[op->var] = value;
if (value.dtype().get_lanes_or_vscale_factor() !=
......@@ -599,20 +639,22 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
// Allocate
Stmt VisitStmt_(const AllocateNode* op) final {
Stmt VisitStmt_(const AllocateNode *op) final {
// Mutate the condition
PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint;
LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
}
// Mutate the extents
Array<PrimExpr> extents;
for (const auto& extent : op->extents) {
for (const auto &extent : op->extents) {
PrimExpr new_ext = this->VisitExpr(extent);
if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint;
LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
}
extents.push_back(new_ext);
......@@ -629,7 +671,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_);
// Rewrite access to the buffer in the body.
Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
Stmt body =
VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
body = this->VisitStmt(body);
return Allocate(op->buffer_var, op->dtype, extents, condition, body);
}
......@@ -641,11 +684,11 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
}
// ProducerStore
Stmt VisitStmt_(const ProducerStoreNode* op) final {
Stmt VisitStmt_(const ProducerStoreNode *op) final {
LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc";
}
private:
private:
// analyzer
arith::Analyzer analyzer_;
// deep equal
......@@ -661,19 +704,22 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
// Let binding
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// vectorizable property
OpAttrMap<TVectorizable> op_vectorizable_ = Op::GetAttrMap<TVectorizable>("TVectorizable");
OpAttrMap<TVectorizable> op_vectorizable_ =
Op::GetAttrMap<TVectorizable>("TVectorizable");
// Mutate an array of expressions with a given lane requirement;
// when finished, *p_lanes is updated to the required number of lanes.
Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int* p_lanes) {
if (arr.size() == 0) return arr;
int& lanes = *p_lanes;
Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int *p_lanes) {
if (arr.size() == 0)
return arr;
int &lanes = *p_lanes;
bool changed = false;
std::vector<PrimExpr> new_arr(arr.size());
for (size_t i = 0; i < arr.size(); i++) {
PrimExpr old_elem = arr[i];
PrimExpr new_elem = this->VisitExpr(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
if (!new_elem.same_as(old_elem))
changed = true;
new_arr[i] = new_elem;
lanes = std::max(lanes, new_elem.dtype().lanes());
}
......@@ -684,12 +730,13 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
changed = true;
}
}
if (!changed) return arr;
if (!changed)
return arr;
return Array<PrimExpr>(new_arr);
}
template <typename TOp, typename T>
PrimExpr BinaryVec(const T* op) {
static_assert(std::is_same<typename TOp::ContainerType, T>::value, "constraint");
template <typename TOp, typename T> PrimExpr BinaryVec(const T *op) {
static_assert(std::is_same<typename TOp::ContainerType, T>::value,
"constraint");
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
......@@ -698,12 +745,14 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(a_lanes, b_lanes);
bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return TOp(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable));
bool is_scalable =
a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return TOp(BroadcastTo(a, lanes, is_scalable),
BroadcastTo(b, lanes, is_scalable));
}
}
template <typename T, typename FCompute>
PrimExpr AddSubVec(const T* op, FCompute fcompute) {
PrimExpr AddSubVec(const T *op, FCompute fcompute) {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
......@@ -713,21 +762,25 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(a_lanes, b_lanes);
if (lanes != 1) {
const RampNode* b_ramp = b.as<RampNode>();
const RampNode* a_ramp = a.as<RampNode>();
const RampNode *b_ramp = b.as<RampNode>();
const RampNode *a_ramp = a.as<RampNode>();
if (a.dtype().is_scalar() && b_ramp) {
return Ramp(fcompute(a, b_ramp->base),
fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes);
return Ramp(
fcompute(a, b_ramp->base),
fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride),
b_ramp->lanes);
}
if (b.dtype().is_scalar() && a_ramp) {
return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return fcompute(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable));
bool is_scalable =
a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return fcompute(BroadcastTo(a, lanes, is_scalable),
BroadcastTo(b, lanes, is_scalable));
}
}
};
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tl
} // namespace tvm
\ No newline at end of file
......@@ -34,19 +34,19 @@ namespace tl {
using namespace tir;
class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
public:
public:
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
FrontendLegalizer substituter(&analyzer);
PrimFuncNode* fptr = f.CopyOnWrite();
PrimFuncNode *fptr = f.CopyOnWrite();
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
private:
using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;
Stmt VisitStmt_(const ForNode* node) final {
Stmt VisitStmt_(const ForNode *node) final {
if (node->kind == ForKind::kParallel) {
parallel_for_scope_++;
}
......@@ -57,7 +57,7 @@ class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
return n;
}
PrimExpr VisitExpr_(const VarNode* node) final {
PrimExpr VisitExpr_(const VarNode *node) final {
if (let_bindings_.count(node)) {
return arith::IRMutatorWithAnalyzer::VisitExpr(let_bindings_[node]);
} else {
......@@ -65,18 +65,18 @@ class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
}
}
Stmt VisitStmt_(const LetStmtNode* node) final {
Stmt VisitStmt_(const LetStmtNode *node) final {
let_bindings_[node->var.get()] = node->value;
return arith::IRMutatorWithAnalyzer::VisitStmt(node->body);
}
PrimExpr VisitExpr_(const LetNode* node) final {
PrimExpr VisitExpr_(const LetNode *node) final {
let_bindings_[node->var.get()] = node->value;
return arith::IRMutatorWithAnalyzer::VisitExpr(node->body);
}
int parallel_for_scope_ = 0;
std::unordered_map<const VarNode*, PrimExpr> let_bindings_;
std::unordered_map<const VarNode *, PrimExpr> let_bindings_;
};
using namespace tir::transform;
......@@ -91,5 +91,5 @@ Pass FrontendLegalize() {
TVM_REGISTER_GLOBAL("tl.transform.FrontendLegalize")
.set_body_typed(FrontendLegalize);
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -38,10 +38,10 @@ using namespace tir;
enum class Proxy { kGeneric, kAsync, kBoth };
class ProxyMarker : public StmtVisitor {
public:
public:
ProxyMarker() = default;
Proxy GetProxy(const StmtNode* stmt) const {
Proxy GetProxy(const StmtNode *stmt) const {
auto it = map_.find(stmt);
// ICHECK(it != map_.end());
// TODO: This is a hack to avoid the ICHECK failure.
......@@ -51,9 +51,9 @@ class ProxyMarker : public StmtVisitor {
return it->second;
}
Proxy GetProxy(const Stmt& stmt) const { return GetProxy(stmt.get()); }
Proxy GetProxy(const Stmt &stmt) const { return GetProxy(stmt.get()); }
void VisitStmt_(const EvaluateNode* op) final {
void VisitStmt_(const EvaluateNode *op) final {
Proxy proxy = Proxy::kAsync;
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(LDMatrixOp()) || call->op.same_as(STMatrixOp())) {
......@@ -63,12 +63,12 @@ class ProxyMarker : public StmtVisitor {
SetProxy(op, proxy);
}
void VisitStmt_(const BufferStoreNode* op) final {
void VisitStmt_(const BufferStoreNode *op) final {
Proxy proxy = Proxy::kGeneric;
SetProxy(op, proxy);
}
void VisitStmt_(const SeqStmtNode* op) final {
void VisitStmt_(const SeqStmtNode *op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetProxy(op->seq[0]);
for (auto stmt : op->seq) {
......@@ -80,61 +80,59 @@ class ProxyMarker : public StmtVisitor {
SetProxy(op, role);
}
void VisitStmt_(const IfThenElseNode* op) final {
void VisitStmt_(const IfThenElseNode *op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetProxy(op->then_case);
if (op->else_case.defined()) {
auto role_else = GetProxy(op->else_case.value());
if (role != role_else) role = Proxy::kBoth;
if (role != role_else)
role = Proxy::kBoth;
}
SetProxy(op, role);
}
void VisitStmt_(const BlockRealizeNode* op) final {
void VisitStmt_(const BlockRealizeNode *op) final {
StmtVisitor::VisitStmt_(op);
SetProxy(op, GetProxy(op->block));
}
template <class NodeType>
void HandleBodyStmt(const NodeType* op) {
template <class NodeType> void HandleBodyStmt(const NodeType *op) {
StmtVisitor::VisitStmt_(op);
SetProxy(op, GetProxy(op->body));
}
void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
private:
void SetProxy(const StmtNode* stmt, Proxy proxy) { map_[stmt] = proxy; }
std::unordered_map<const StmtNode*, Proxy> map_;
private:
void SetProxy(const StmtNode *stmt, Proxy proxy) { map_[stmt] = proxy; }
std::unordered_map<const StmtNode *, Proxy> map_;
};
class InjectFenceProxy : public StmtExprMutator {
public:
public:
static PrimFunc Substitute(PrimFunc f) {
auto T = InjectFenceProxy();
f.CopyOnWrite()->body = T(f->body);
return f;
}
private:
Proxy get_generic_proxy(const Stmt& stmt) {
private:
Proxy get_generic_proxy(const Stmt &stmt) {
auto marker = ProxyMarker();
marker(stmt);
return marker.GetProxy(stmt);
}
Stmt VisitStmt_(const SeqStmtNode* op) final {
Stmt VisitStmt_(const SeqStmtNode *op) final {
ICHECK(op->seq.size() > 0);
Array<Stmt> new_body;
Proxy cur_proxy, prev_proxy;
auto fence_stmt = Evaluate(Call(DataType::Handle(), FenceProxyAsyncOp(), {}));
auto fence_stmt =
Evaluate(Call(DataType::Handle(), FenceProxyAsyncOp(), {}));
prev_proxy = get_generic_proxy(op->seq[0]);
new_body.push_back(VisitStmt(op->seq[0]));
if (op->seq.size() > 1) {
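// The elided loop that follows presumably walks op->seq[1..] and inserts
// fence_stmt whenever the proxy inferred for the next statement differs from
// prev_proxy, so generic-proxy writes are fenced before later async-proxy
// operations.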
......@@ -171,5 +169,5 @@ tvm::transform::Pass InjectFenceProxy() {
TVM_REGISTER_GLOBAL("tl.transform.InjectFenceProxy")
.set_body_typed(InjectFenceProxy);
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -19,7 +19,8 @@
/*!
* \file inject_software_pipeline.cc
* \brief Transform annotated loops into pipelined ones that parallelize producers and consumers
* \brief Transform annotated loops into pipelined ones that parallelize
* producers and consumers
*/
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
......@@ -38,24 +39,27 @@ using namespace tir;
/*!
* \brief Create a block and infer the access region with the given body.
*
* The result is an opaque block that doesn't contain any block iter vars. In case the body is a
* block realize without predicate, it is unnecessary to create a new block; the block of the block
* realize will be returned.
* The result is an opaque block that doesn't contain any block iter vars. In
* case the body is a block realize without predicate, it is unnecessary to
* create a new block; the block of the block realize will be returned.
*
* \param body The body of the block.
* \param buffer_data_to_buffer The map from buffer data to buffer.
* \return The result block.
*/
Block MakeBlock(const Stmt& body, const Map<Var, Buffer>& buffer_data_to_buffer) {
if (const BlockRealizeNode* block_realize = body.as<BlockRealizeNode>()) {
Block MakeBlock(const Stmt &body,
const Map<Var, Buffer> &buffer_data_to_buffer) {
if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
if (is_one(block_realize->predicate)) {
// no need to create a new block
return block_realize->block;
}
}
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ body);
Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer);
BlockNode* n = block.CopyOnWrite();
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ body);
Array<Array<BufferRegion>> access =
GetBlockReadWriteRegion(block, buffer_data_to_buffer);
BlockNode *n = block.CopyOnWrite();
n->reads = access[0];
n->writes = access[1];
return block;
......@@ -68,69 +72,76 @@ struct PipelineAnnotation {
bool async;
};
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation, ObjectPtrHash, ObjectPtrEqual>;
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
ObjectPtrHash, ObjectPtrEqual>;
struct BufferAccessInfo {
int def = -1; // the defining stage of the buffer
int use = -1; // the last using stage of the buffer
int def = -1; // the defining stage of the buffer
int use = -1; // the last using stage of the buffer
};
/*!
* \brief Rewriter for the body of the software pipeline. This pass inserts `floormod` to indices
* of the remapped buffer to select the version corresponding to the pipeline stage.
* \brief Rewriter for the body of the software pipeline. This pass inserts
* `floormod` to indices of the remapped buffer to select the version
* corresponding to the pipeline stage.
*/
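// Example of the multi-versioning handled here: a double-buffered shared
// buffer originally shaped [128] is remapped to shape [2, 128]; its block
// access regions gain a leading floormod(loop_var - min, 2) version index
// (RewritePipelineBufferRegion below), and flat intrinsic accesses get
// floormod(loop_var, 2) * 128 added to their element offset
// (RewriteBufferAccess below).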
class PipelineBodyRewriter : public StmtExprMutator {
public:
public:
/*!
* \brief Constructor of PipelineBodyRewriter.
* \param buffer_data_to_buffer The map from buffer data to buffer.
* \param buffer_remap The map from original buffer to the buffer with updated shape for
* multi-versioning in the software pipeline.
* \param pipeline_loop The original loop to be software pipelined.
* \param access_all_versions Whether all versions of the buffers in the software pipeline are
* accessed. This will be used to update block access region. In the prologue and epilogue
* of a two-stage software pipeline, only one version of these buffers is accessed.
* \param buffer_remap The map from original buffer to the buffer with updated
* shape for multi-versioning in the software pipeline. \param pipeline_loop
* The original loop to be software pipelined. \param access_all_versions
* Whether all versions of the buffers in the software pipeline are accessed.
* This will be used to update block access region. In the prologue and
* epilogue of a two-stage software pipeline, only one version of these
* buffers is accessed.
*/
PipelineBodyRewriter(const Map<Var, Buffer>& buffer_data_to_buffer,
const Map<Buffer, Buffer>& buffer_remap, For pipeline_loop,
bool access_all_versions)
PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
const Map<Buffer, Buffer> &buffer_remap,
For pipeline_loop, bool access_all_versions)
: buffer_data_to_buffer_(buffer_data_to_buffer),
buffer_remap_(buffer_remap),
pipeline_loop_(pipeline_loop),
buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop),
access_all_versions_(access_all_versions) {}
private:
BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const {
private:
BufferRegion
RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
auto it = buffer_remap_.find(buffer_region->buffer);
if (it != buffer_remap_.end()) {
Region new_region = buffer_region->region;
const Buffer& new_buffer = (*it).second;
// For pipeline buffers, relax the access region of the first dimension to full extent
// if access_all_versions == true
const Buffer &new_buffer = (*it).second;
// For pipeline buffers, relax the access region of the first dimension to
// full extent if access_all_versions == true
Range accessed_version =
access_all_versions_
? Range::FromMinExtent(0, new_buffer->shape[0])
: Range::FromMinExtent(floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
new_buffer->shape[0]),
Integer(1));
: Range::FromMinExtent(
floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
new_buffer->shape[0]),
Integer(1));
new_region.insert(new_region.begin(), accessed_version);
return BufferRegion(new_buffer, new_region);
}
return buffer_region;
}
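// Example (sketch, assuming a buffer remapped to 2 versions): a region
// B[0:16] becomes B[0:2, 0:16] when access_all_versions is true; otherwise
// the leading dimension is narrowed to the single version
// floormod(i - min, 2) with extent 1, where `i` is the pipeline loop variable.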
PrimExpr RewriteBufferAccess(const Call& call, const std::vector<int> arg_indices) {
auto product = [](const Array<PrimExpr>& input) {
return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), input);
PrimExpr RewriteBufferAccess(const Call &call,
const std::vector<int> arg_indices) {
auto product = [](const Array<PrimExpr> &input) {
return foldl(
[](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), input);
};
Array<PrimExpr> new_args = call->args;
for (int i : arg_indices) {
const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
const Buffer &buffer =
buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
auto it = buffer_remap_.find(buffer);
if (it != buffer_remap_.end()) {
const Buffer& new_buffer = (*it).second;
const PrimExpr& old_index = call->args[i + 1];
const Buffer &new_buffer = (*it).second;
const PrimExpr &old_index = call->args[i + 1];
PrimExpr offset;
if (new_buffer->strides.empty()) {
offset = product(buffer->shape);
......@@ -138,62 +149,63 @@ class PipelineBodyRewriter : public StmtExprMutator {
offset = new_buffer->strides[0];
}
PrimExpr new_index =
old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
old_index +
floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
new_args.Set(i + 1, new_index);
}
}
return Call(call->dtype, call->op, new_args, call->span);
}
Stmt VisitStmt_(const BlockNode* op) final {
for (const Buffer& alloc_buffer : op->alloc_buffers) {
Stmt VisitStmt_(const BlockNode *op) final {
for (const Buffer &alloc_buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
}
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
BlockNode* n = block.CopyOnWrite();
n->reads.MutateByApply([this](const BufferRegion& buffer_region) {
BlockNode *n = block.CopyOnWrite();
n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
return RewritePipelineBufferRegion(buffer_region);
});
n->writes.MutateByApply([this](const BufferRegion& buffer_region) {
n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
return RewritePipelineBufferRegion(buffer_region);
});
for (const Buffer& alloc_buffer : op->alloc_buffers) {
for (const Buffer &alloc_buffer : op->alloc_buffers) {
buffer_data_to_buffer_.erase(alloc_buffer->data);
}
return std::move(block);
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt VisitStmt_(const BufferStoreNode *op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto it = buffer_remap_.find(store->buffer);
if (it == buffer_remap_.end()) {
return std::move(store);
}
const Buffer& new_buffer = (*it).second;
auto* n = store.CopyOnWrite();
const Buffer &new_buffer = (*it).second;
auto *n = store.CopyOnWrite();
n->buffer = new_buffer;
PrimExpr version =
floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
PrimExpr version = floormod(
(pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
n->indices.insert(n->indices.begin(), version);
return std::move(store);
}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto it = buffer_remap_.find(load->buffer);
if (it == buffer_remap_.end()) {
return std::move(load);
}
const Buffer& new_buffer = (*it).second;
auto* n = load.CopyOnWrite();
const Buffer &new_buffer = (*it).second;
auto *n = load.CopyOnWrite();
n->buffer = new_buffer;
PrimExpr version =
floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
PrimExpr version = floormod(
(pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
n->indices.insert(n->indices.begin(), version);
return std::move(load);
}
PrimExpr VisitExpr_(const CallNode* op) final {
PrimExpr VisitExpr_(const CallNode *op) final {
Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (call->op.same_as(builtin::tvm_access_ptr())) {
return RewriteBufferAccess(call, {1});
......@@ -208,24 +220,25 @@ class PipelineBodyRewriter : public StmtExprMutator {
};
/*!
* \brief Rewriter for the software pipeline that rewrite a loop into a pipelined one.
* \brief Rewriter for the software pipeline that rewrites a loop into a
* pipelined one.
*/
class PipelineRewriter : public StmtExprMutator {
public:
PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer, const Array<Buffer>& pipeline_allocs,
const For& pipeline_loop, const PipelineInfo& pipeline_info)
public:
PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
const Array<Buffer> &pipeline_allocs,
const For &pipeline_loop, const PipelineInfo &pipeline_info)
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
pipeline_allocs_(pipeline_allocs),
pipeline_loop_(pipeline_loop),
pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
pipeline_info_(pipeline_info) {}
Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions
// need to maintain for each buffer.
std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos =
GetBufferAccessInfo();
for (const Buffer& buffer : pipeline_allocs_) {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the
// number of versions that need to be maintained for each buffer.
std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
infos = GetBufferAccessInfo();
for (const Buffer &buffer : pipeline_allocs_) {
int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
if (num_versions > 1) {
buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
......@@ -233,27 +246,28 @@ class PipelineRewriter : public StmtExprMutator {
}
ordered_stmts_.resize(pipeline_info_.size());
for (const auto& [block, anno] : pipeline_info_) {
for (const auto &[block, anno] : pipeline_info_) {
ordered_stmts_.Set(anno.order, block);
}
for (const Block& block : ordered_stmts_) {
for (const Block &block : ordered_stmts_) {
int stage = pipeline_info_[block].stage;
if (pipeline_info_[block].async) {
auto& state = async_states[stage];
auto &state = async_states[stage];
state.producer_head = pipeline_loop_->min - 1;
for (auto write_region : block->writes) {
auto buffer = write_region->buffer;
state.dst_buffers.insert(buffer.get());
if (buffer_remap_.count(buffer)) state.dst_buffers.insert(buffer_remap_[buffer].get());
if (buffer_remap_.count(buffer))
state.dst_buffers.insert(buffer_remap_[buffer].get());
}
}
}
std::unordered_set<int> consumed;
for (const Block& block : ordered_stmts_) {
for (const Block &block : ordered_stmts_) {
int stage = pipeline_info_[block].stage;
if (pipeline_info_[block].async) {
auto& state = async_states[stage];
auto &state = async_states[stage];
if (state.commit_groups.empty() || consumed.count(stage)) {
state.commit_groups.push_back({});
}
......@@ -263,13 +277,15 @@ class PipelineRewriter : public StmtExprMutator {
auto buffer = buffer_remap_.count(write_region->buffer)
? buffer_remap_[write_region->buffer]
: write_region->buffer;
state.buffer_to_commit_group_[buffer.get()] = state.commit_groups.size() - 1;
state.buffer_to_commit_group_[buffer.get()] =
state.commit_groups.size() - 1;
}
}
for (auto read_region : block->reads) {
for (const auto& [producer_stage_id, producer_state] : async_states) {
if (producer_stage_id <= stage && producer_state.writes(read_region->buffer)) {
for (const auto &[producer_stage_id, producer_state] : async_states) {
if (producer_stage_id <= stage &&
producer_state.writes(read_region->buffer)) {
consumed.insert(producer_stage_id);
}
}
......@@ -277,17 +293,21 @@ class PipelineRewriter : public StmtExprMutator {
}
// Step 2: Emit the pipeline prologue, body and epilogue.
Stmt prologue = EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true, true);
Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
pipeline_loop_->min + pipeline_loop_->extent, false, false);
Stmt epilogue = EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);
Stmt prologue = EmitImpl(pipeline_loop_->min,
pipeline_loop_->min + max_stage_, true, true);
Stmt body =
EmitImpl(pipeline_loop_->min + max_stage_,
pipeline_loop_->min + pipeline_loop_->extent, false, false);
Stmt epilogue = EmitImpl(
pipeline_loop_->min + pipeline_loop_->extent,
pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);
SeqStmt stmt = SeqStmt({prologue, body, epilogue});
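// Example (sketch, assuming min == 0, extent == N and max_stage_ == 1): the
// prologue covers iterations [0, 1), the body covers [1, N), and the epilogue
// covers [N, N + 1); bound checks inside the prologue/epilogue mask out the
// stages whose skewed iteration falls outside the original range [0, N).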
// Step 3: Make a new block that contains new buffer allocations after pipeline rewriting.
// Step 3: Make a new block that contains new buffer allocations after
// pipeline rewriting.
Array<Buffer> alloc_buffers;
for (const auto& alloc : pipeline_allocs_) {
for (const auto &alloc : pipeline_allocs_) {
alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
buffer_data_to_buffer_.erase(alloc->data);
}
......@@ -296,26 +316,28 @@ class PipelineRewriter : public StmtExprMutator {
return BlockRealize({}, Bool(true), block);
}
private:
private:
/*!
* \brief Analyze accesses to the buffers in the software pipeline.
*
* This method check the 'define' and 'use' stage of the buffers in the software pipeline, which
* can be used to compute the number of versions needed to maintain after rewriting.
* This method checks the 'define' and 'use' stages of the buffers in the
* software pipeline, which can be used to compute the number of versions
* that need to be maintained after rewriting.
*/
std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
GetBufferAccessInfo() {
std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos;
for (const auto& pair : pipeline_info_) {
const Block& block = pair.first;
std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
infos;
for (const auto &pair : pipeline_info_) {
const Block &block = pair.first;
int stage = pair.second.stage;
max_stage_ = std::max(max_stage_, stage);
for (const BufferRegion& write : block->writes) {
for (const BufferRegion &write : block->writes) {
if (!infos.count(write->buffer)) {
infos.emplace(write->buffer, BufferAccessInfo{});
}
auto& info = infos.at(write->buffer);
auto &info = infos.at(write->buffer);
if (info.def == -1) {
info.def = stage;
} else {
......@@ -323,11 +345,11 @@ class PipelineRewriter : public StmtExprMutator {
}
}
for (const BufferRegion& read : block->reads) {
for (const BufferRegion &read : block->reads) {
if (!infos.count(read->buffer)) {
infos.emplace(read->buffer, BufferAccessInfo{});
}
auto& info = infos.at(read->buffer);
auto &info = infos.at(read->buffer);
info.use = std::max(info.use, stage);
}
}
......@@ -355,58 +377,64 @@ class PipelineRewriter : public StmtExprMutator {
}
/*!
* \brief Compute the number of versions need to maintain for buffer accessed in the software
* pipeline.
* \brief Compute the number of versions that need to be maintained for a
* buffer accessed in the software pipeline.
*
* This method applies liveness analysis to the target buffer to compute the number of versions
* need to maintain during the software pipeline.
* Annotation `attr::double_buffer_scope` is handled here which provides a way to override the
* result of the analysis. Additional double buffering in the software pipeline can be useful
* to eliminate synchronizations in GPU devices.
* This method applies liveness analysis to the target buffer to compute the
* number of versions that need to be maintained during the software pipeline.
* The annotation `attr::double_buffer_scope` is handled here, which provides a
* way to override the result of the analysis. Additional double buffering in
* the software pipeline can be useful to eliminate synchronizations on GPU
* devices.
*
* \param buffer The target buffer
* \param buffer_info The access information of the target buffer.
* \return The number of versions required for the target buffer.
*/
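// Worked example (sketch): a buffer written in stage 0 (def = 0) and last
// read in stage 1 (use = 1) has the upper bound use - def + 1 = 2, so it is
// double buffered unless the reader/writer analysis below proves that two
// versions are unnecessary, in which case a single version is kept.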
int ComputeBufferVersions(const Buffer& buffer, const BufferAccessInfo& buffer_info) {
int ComputeBufferVersions(const Buffer &buffer,
const BufferAccessInfo &buffer_info) {
if (buffer_info.def == -1) {
// Keep the original number of versions as buffers defined outside the software pipeline
// should not be mutated.
// Keep the original number of versions as buffers defined outside the
// software pipeline should not be mutated.
return 1;
}
// `use - def + 1` is an upper bound of the needed versions
// We optimize a few case where the number of versions can be smaller than the upper bound
// We optimize a few cases where the number of versions can be smaller than
// the upper bound.
int num_versions = buffer_info.use - buffer_info.def + 1;
if (num_versions >= 2) {
// A special case when `use - def + 1 == 2`. Double buffering is only needed in this case when
// these exists a reader block_i and a writer block_j such that
// order(block_i) < order(block_j) and stage(block_i) < stage(block_j) and the access regions
// of block_i and block_j overlap.
// A special case when `use - def + 1 == 2`. Double buffering is only
// needed in this case when there exists a writer block_i and a reader
// block_j such that order(block_i) < order(block_j) and stage(block_i) <
// stage(block_j) and the access regions of block_i and block_j overlap.
bool need_multi_version = false;
for (const auto& pair1 : pipeline_info_) {
const Block& writer_block = pair1.first;
const auto& writer_info = pair1.second;
for (const auto &pair1 : pipeline_info_) {
const Block &writer_block = pair1.first;
const auto &writer_info = pair1.second;
auto it1 = std::find_if(writer_block->writes.begin(), writer_block->writes.end(),
[&](const BufferRegion& buffer_region) {
auto it1 = std::find_if(writer_block->writes.begin(),
writer_block->writes.end(),
[&](const BufferRegion &buffer_region) {
return buffer_region->buffer.same_as(buffer);
});
if (it1 == writer_block->writes.end()) {
continue;
}
for (const auto& pair2 : pipeline_info_) {
const Block& reader_block = pair2.first;
const auto& reader_info = pair2.second;
auto it2 = std::find_if(reader_block->reads.begin(), reader_block->reads.end(),
[&](const BufferRegion& buffer_region) {
return buffer_region->buffer.same_as(buffer);
});
for (const auto &pair2 : pipeline_info_) {
const Block &reader_block = pair2.first;
const auto &reader_info = pair2.second;
auto it2 = std::find_if(
reader_block->reads.begin(), reader_block->reads.end(),
[&](const BufferRegion &buffer_region) {
return buffer_region->buffer.same_as(buffer);
});
if (it2 == reader_block->reads.end()) {
continue;
}
if (writer_info.order < reader_info.order && writer_info.stage < reader_info.stage &&
if (writer_info.order < reader_info.order &&
writer_info.stage < reader_info.stage &&
MayConflict((*it1)->region, (*it2)->region)) {
need_multi_version = true;
break;
......@@ -421,13 +449,12 @@ class PipelineRewriter : public StmtExprMutator {
}
/*!
* \brief Rewrite buffer allocation to keep multiple versions of original buffer for pipelined
* accesses.
* \param buffer The buffer to be resized.
* \brief Rewrite the buffer allocation to keep multiple versions of the
* original buffer for pipelined accesses.
* \param buffer The buffer to be resized.
* \param num_versions The number of versions to keep.
* \return The resized buffer.
*/
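// Example (sketch): a shared buffer of shape [64, 64] kept in 2 versions is
// reallocated with shape [2, 64, 64]; PipelineBodyRewriter later prepends the
// version index to every access of the remapped buffer.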
Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) {
Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (new_buffer->strides.size()) {
......@@ -438,29 +465,32 @@ class PipelineRewriter : public StmtExprMutator {
return Buffer(new_buffer);
}
// Per-stage states that need to be tracked across pipeline prologue, body, and epilogue.
// Per-stage states that need to be tracked across pipeline prologue, body,
// and epilogue.
struct AsyncStateGlobal {
// Buffers that this stage asynchronously writes.
std::unordered_set<const BufferNode*> dst_buffers;
// An imaginary index that the latest async operation associated with this stage has written
// into. Only valid if all associated predicates are true, so that we can count the number of
// async invocations exactly. When it is valid, it is the "sum of extents of loops that have
// been executed" - 1, e.g. for epilogue it is prologue extent + body extent - 1. This
// is only needed to compute wait count for epilogue without async producers.
std::unordered_set<const BufferNode *> dst_buffers;
// An imaginary index that the latest async operation associated with this
// stage has written into. Only valid if all associated predicates are true,
// so that we can count the number of async invocations exactly. When it is
// valid, it is the "sum of extents of loops that have been executed" - 1,
// e.g. for epilogue it is prologue extent + body extent - 1. This is only
// needed to compute wait count for epilogue without async producers.
PrimExpr producer_head;
std::vector<std::vector<int>> commit_groups;
std::unordered_map<const BufferNode*, int> buffer_to_commit_group_;
std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
};
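// Example (sketch, assuming pipeline_loop_->min == 0, extent == N and
// max_stage_ == 1): producer_head starts at -1, becomes 0 after the prologue
// (extent 1) and N - 1 after the body (extent N - 1), matching the
// "sum of extents of executed loops" - 1 description above.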
// Per-stage states that are local to each of pipeline prologue, body, and epilogue.
// Per-stage states that are local to each of pipeline prologue, body, and
// epilogue.
struct AsyncStateLocal {
struct PendingWait {
// The index into a list of blocks, where async_wait_queue should be attached at the
// beginning.
// The index into a list of blocks, where async_wait_queue should be
// attached at the beginning.
int insert_before;
// in_flight_count would be a more precise name, but the implementation uses wait_count for
// brevity.
// in_flight_count would be a more precise name, but the implementation
// uses wait_count for brevity.
PrimExpr wait_count{nullptr};
bool valid() const { return wait_count.defined(); }
......@@ -468,8 +498,8 @@ class PipelineRewriter : public StmtExprMutator {
std::vector<PendingWait> pending_waits;
// A symbolic expression representing the index the latest async operation associated with this
// stage has written into, at the "current" iteration.
// A symbolic expression representing the index the latest async operation
// associated with this stage has written into, at the "current" iteration.
Optional<PrimExpr> producer_head;
};
......@@ -483,31 +513,35 @@ class PipelineRewriter : public StmtExprMutator {
bool is_async;
};
void PopulateWaitCounts(const std::vector<RewrittenBlockInfo>& new_blocks,
std::map<int, AsyncStateLocal>* async_states_local) {
void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
std::map<int, AsyncStateLocal> *async_states_local) {
for (size_t i = 0; i < new_blocks.size(); ++i) {
int producer_stage_idx = -1;
for (auto read_region : new_blocks[i].block->reads) {
for (const auto& [stage, state] : async_states) {
if (stage <= new_blocks[i].stage && state.writes(read_region->buffer)) {
// Found an earlier stage where read_region->buffer was asynchronously written
for (const auto &[stage, state] : async_states) {
if (stage <= new_blocks[i].stage &&
state.writes(read_region->buffer)) {
// Found an earlier stage where read_region->buffer was
// asynchronously written
ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
<< "A dependency on multiple async stages is not supported";
producer_stage_idx = stage;
}
}
}
if (producer_stage_idx == -1) continue;
const auto& state = async_states[producer_stage_idx];
auto& dep_local_state = (*async_states_local)[producer_stage_idx];
if (producer_stage_idx == -1)
continue;
const auto &state = async_states[producer_stage_idx];
auto &dep_local_state = (*async_states_local)[producer_stage_idx];
PrimExpr in_flight_cnt = 0;
for (const auto& group : state.commit_groups) {
for (const auto &group : state.commit_groups) {
PrimExpr consumer_head = new_blocks[i].access_index;
PrimExpr producer_head;
if (dep_local_state.producer_head.defined()) {
producer_head = dep_local_state.producer_head.value();
// if the group is after the wait point, subtract 1
if (group.front() > new_blocks[i].order) producer_head -= 1;
if (group.front() > new_blocks[i].order)
producer_head -= 1;
} else {
producer_head = state.producer_head;
}
......@@ -516,41 +550,43 @@ class PipelineRewriter : public StmtExprMutator {
// We can relax the in-flight count by the number of trailing independent commit groups.
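// Illustrative example: with commit groups {g0, g1, g2} and a consumer that
// only reads buffers committed by g0, the trailing groups g2 and g1 are
// independent, so the loop below adds 2 to in_flight_cnt before it reaches
// the dependent group and stops relaxing.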
std::unordered_set<int> dependent_groups;
for (const auto& read_region : new_blocks[i].block->reads) {
for (const auto &read_region : new_blocks[i].block->reads) {
if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
dependent_groups.insert(state.buffer_to_commit_group_.at(read_region->buffer.get()));
dependent_groups.insert(
state.buffer_to_commit_group_.at(read_region->buffer.get()));
}
for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
if (dependent_groups.count(i) == 0)
in_flight_cnt += 1;
else
break; // stop relaxing
break; // stop relaxing
}
in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
dep_local_state.pending_waits.push_back({static_cast<int>(i), in_flight_cnt});
dep_local_state.pending_waits.push_back(
{static_cast<int>(i), in_flight_cnt});
}
}
// Given pipelined blocks and async-related information, generate final loop statements with async
// scopes (if any).
// Given pipelined blocks and async-related information, generate final loop
// statements with async scopes (if any).
Array<Stmt> CompletePipelineLoopStatements(
const std::vector<RewrittenBlockInfo>& blocks,
const std::map<int, AsyncStateLocal>& async_states_local) const {
const std::vector<RewrittenBlockInfo> &blocks,
const std::map<int, AsyncStateLocal> &async_states_local) const {
std::vector<RewrittenBlockInfo> new_blocks = blocks;
for (const auto& [stage_id, state] : async_states_local) {
for (const auto& pw : state.pending_waits) {
auto& block = new_blocks[pw.insert_before].block;
BlockNode* n = block.CopyOnWrite();
for (const auto &[stage_id, state] : async_states_local) {
for (const auto &pw : state.pending_waits) {
auto &block = new_blocks[pw.insert_before].block;
BlockNode *n = block.CopyOnWrite();
auto zero = make_zero(DataType::Int(32));
n->body =
AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
AttrStmt(zero, tir::attr::async_wait_inflight_count, pw.wait_count, n->body));
n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
AttrStmt(zero, tir::attr::async_wait_inflight_count,
pw.wait_count, n->body));
}
}
// mark the last async stmt of each commit group as the commit point
std::unordered_set<int> commit_group_indices;
for (const auto& [stage_id, state] : async_states) {
for (const auto &[stage_id, state] : async_states) {
for (size_t i = 0; i < state.commit_groups.size(); ++i) {
commit_group_indices.insert(state.commit_groups[i].back());
}
......@@ -561,9 +597,9 @@ class PipelineRewriter : public StmtExprMutator {
for (size_t i = 0; i < new_blocks.size(); i++) {
Block block = new_blocks[i].block;
if (commit_group_indices.count(new_blocks[i].order)) {
auto commit_queue_scope =
AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_commit_queue_scope,
new_blocks[i].stage, block->body);
auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
tir::attr::async_commit_queue_scope,
new_blocks[i].stage, block->body);
block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
}
stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
......@@ -579,15 +615,18 @@ class PipelineRewriter : public StmtExprMutator {
* \param unroll_loop Whether the loop should be unrolled.
* \return The result loop.
*/
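// Example (sketch): a statement annotated with stage s is executed at the
// skewed iteration (i - s); in the prologue and epilogue the predicate
// min <= i - s < min + extent masks out iterations that fall outside the
// original loop range.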
Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop, bool need_bound_check) {
Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop,
bool need_bound_check) {
PrimExpr new_loop_var;
PrimExpr extent = end - start;
auto make_nop = []() { return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); };
auto make_nop = []() {
return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
};
bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
if (is_unit_loop) {
new_loop_var = start; // use constants as the loop var for unit loops
new_loop_var = start; // use constants as the loop var for unit loops
} else {
new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
......@@ -598,45 +637,52 @@ class PipelineRewriter : public StmtExprMutator {
// Async related
std::map<int, AsyncStateLocal> async_states_local;
for (const Block& block : ordered_stmts_) {
for (const Block &block : ordered_stmts_) {
int stage = pipeline_info_.at(block).stage;
int order = pipeline_info_.at(block).order;
PrimExpr inbound = Bool(true);
PrimExpr skewed_loop_var = new_loop_var - stage;
if (need_bound_check)
inbound = analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
(skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
inbound =
analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
(skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
if (analyzer_.CanProve(!inbound)) {
continue;
}
Block new_block = Downcast<Block>(PipelineBodyRewriter(
buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, max_stage_ != 1)(block));
Block new_block = Downcast<Block>(
PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
pipeline_loop_, max_stage_ != 1)(block));
PrimExpr delta = start - pipeline_loop_->min;
// This variable corresponds to
// - "producer_head" if this stage is an async producer
// - "consumer_head" if this stage reads from asynchronously written buffers.
PrimExpr normalized_access_index = is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
// - "consumer_head" if this stage reads from asynchronously written
// buffers.
PrimExpr normalized_access_index =
is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
// Adjust the block predicate and the body according to the final loop bound
// Adjust the block predicate and the body according to the final loop
// bound
// [pipeline_loop_->min, extent).
if (!is_unit_loop) {
Var loop_iter = Downcast<Var>(new_loop_var);
inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
}
new_block = Downcast<Block>(
Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
new_block = Downcast<Block>(Substitute(
new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
if (pipeline_info_[block].async) {
auto& local_state = async_states_local[stage];
auto &local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index;
BlockNode* n = new_block.CopyOnWrite();
n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, 1, n->body);
BlockNode *n = new_block.CopyOnWrite();
n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
1, n->body);
}
new_blocks.push_back(
{stage, order, inbound, new_block, normalized_access_index, pipeline_info_[block].async});
new_blocks.push_back({stage, order, inbound, new_block,
normalized_access_index,
pipeline_info_[block].async});
}
PopulateWaitCounts(new_blocks, &async_states_local);
......@@ -655,8 +701,8 @@ class PipelineRewriter : public StmtExprMutator {
if (!is_unit_loop) {
Map<String, ObjectRef> preserved_annotations;
for (const auto& kv : pipeline_loop_->annotations) {
const String& key = kv.first;
for (const auto &kv : pipeline_loop_->annotations) {
const String &key = kv.first;
if (kv.first != tir::attr::software_pipeline_stage &&
kv.first != tir::attr::software_pipeline_order &&
kv.first != tir::attr::software_pipeline_async_stages) {
......@@ -664,16 +710,17 @@ class PipelineRewriter : public StmtExprMutator {
}
}
new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop),
NullOpt, preserved_annotations);
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
std::move(new_loop), NullOpt, preserved_annotations);
}
// Update producer heads in the global async states.
for (const auto& [stage_id, state] : async_states_local) {
for (const auto &[stage_id, state] : async_states_local) {
async_states[stage_id].producer_head += extent;
}
return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
return BlockRealize({}, Bool(true),
MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
}
arith::Analyzer analyzer_;
......@@ -690,22 +737,23 @@ class PipelineRewriter : public StmtExprMutator {
/*!
* \brief Build the dependency graph among an array of blocks.
* \param[in] blocks The array of blocks.
* \param[out] dep_src2dst Optional, a map to store dependency edges from the source to the
* destination.
* \param[out] dep_dst2src Optional, a map to store dependency edges from the
* destination to the source.
* \param[out] dep_src2dst Optional, a map to store dependency edges from the
* source to the destination.
* \param[out] dep_dst2src Optional, a map to store dependency edges from the
* destination to the source.
*/
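// Example (sketch): if block B0 writes buffer `buf` and block B1 later reads
// it, the edge B0 -> B1 is recorded in dep_src2dst[B0] (and B1 -> B0 in
// dep_dst2src[B1] when that map is provided).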
void BuildDependencyGraph(
const Array<Block>& blocks,
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst,
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) {
std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
for (const Block& block : blocks) {
for (const BufferRegion& read : block->reads) {
void BuildDependencyGraph(const Array<Block> &blocks,
std::unordered_map<Block, Array<Block>, ObjectPtrHash,
ObjectPtrEqual> *dep_src2dst,
std::unordered_map<Block, Array<Block>, ObjectPtrHash,
ObjectPtrEqual> *dep_dst2src) {
std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
buffer_writers;
for (const Block &block : blocks) {
for (const BufferRegion &read : block->reads) {
auto it = buffer_writers.find(read->buffer->data);
if (it != buffer_writers.end()) {
for (const Block& writer : it->second) {
for (const Block &writer : it->second) {
if (dep_src2dst != nullptr) {
(*dep_src2dst)[writer].push_back(block);
}
......@@ -715,83 +763,89 @@ void BuildDependencyGraph(
}
}
}
for (const BufferRegion& write : block->writes) {
for (const BufferRegion &write : block->writes) {
buffer_writers[write->buffer->data].push_back(block);
}
}
}
class PipelineInjector : private StmtExprMutator {
public:
static Stmt Inject(const PrimFunc& func) {
public:
static Stmt Inject(const PrimFunc &func) {
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
PipelineInjector injector(global_symbol);
for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
for (const auto &kv : func->buffer_map) {
const Buffer &buffer = kv.second;
injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
}
return injector(func->body);
}
private:
explicit PipelineInjector(Optional<String> global_symbol) : global_symbol_(global_symbol) {}
private:
explicit PipelineInjector(Optional<String> global_symbol)
: global_symbol_(global_symbol) {}
/*!
* \brief Check that the pipeline satisfies the following conditions:
* 1. No conflicting order: The order of each statement should be unique.
* 2. Reordering of statements doesn't break buffer access dependencies. Specifically, for
* dependency (e.g. read-after-write) from statement A to statement B, it requires:
* case 1: stage(A) < stage(B)
* case 2: stage(A) == stage(B) and order(A) < order(B)
* 2. Reordering of statements doesn't break buffer access dependencies.
* Specifically, for a dependency (e.g. read-after-write) from statement A to
* statement B, it requires either:
* case 1: stage(A) < stage(B), or
* case 2: stage(A) == stage(B) and order(A) < order(B).
*/
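// Example (sketch): with stage = [0, 0, 1] and order = [0, 1, 2], a
// read-after-write dependency from statement 0 to statement 1 is accepted
// (same stage, order 0 < 1), whereas swapping their orders would trigger the
// reordering error below.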
void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array<Block>& original_order) {
void ValidatePipelineBody(const PipelineInfo &pipeline_info,
const Array<Block> &original_order) {
std::unordered_set<int> used_orders;
std::unordered_map<int, int> stage_max_order;
std::unordered_map<int, const Block*> order_to_block;
std::unordered_map<const Block*, int> block_to_stage;
for (const Block& block : original_order) {
const auto& stmt_info = pipeline_info.at(block);
std::unordered_map<int, const Block *> order_to_block;
std::unordered_map<const Block *, int> block_to_stage;
for (const Block &block : original_order) {
const auto &stmt_info = pipeline_info.at(block);
int order = stmt_info.order;
CHECK(!used_orders.count(order))
<< "ValueError: Two statements in the software pipeline cannot have the same order";
<< "ValueError: Two statements in the software pipeline cannot have "
"the same order";
used_orders.insert(order);
}
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual> dep_src2dst;
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
dep_src2dst;
BuildDependencyGraph(original_order, &dep_src2dst, nullptr);
for (const auto& pair : dep_src2dst) {
const Block& src = pair.first;
const auto& src_info = pipeline_info.at(src);
const Array<Block>& dsts = pair.second;
for (const Block& dst : dsts) {
const auto& dst_info = pipeline_info.at(dst);
for (const auto &pair : dep_src2dst) {
const Block &src = pair.first;
const auto &src_info = pipeline_info.at(src);
const Array<Block> &dsts = pair.second;
for (const Block &dst : dsts) {
const auto &dst_info = pipeline_info.at(dst);
CHECK_LE(src_info.stage, dst_info.stage)
<< "ValueError: statement " << dst << " in stage " << dst_info.stage
<< " cannot depends on statement " << src << " in a later stage " << src_info.stage;
<< " cannot depends on statement " << src << " in a later stage "
<< src_info.stage;
if (src_info.stage == dst_info.stage) {
CHECK_LT(src_info.order, dst_info.order) << "ValueError: two statements with buffer "
"access dependency in the same stage of the "
"software pipeline cannot be reordered";
CHECK_LT(src_info.order, dst_info.order)
<< "ValueError: two statements with buffer "
"access dependency in the same stage of the "
"software pipeline cannot be reordered";
}
}
}
}
Stmt VisitStmt_(const ForNode* op) final {
Stmt VisitStmt_(const ForNode *op) final {
// Step 1: Recursively rewrite the children first.
For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
if (!HasPipelineAnnotation(op)) {
return std::move(for_node);
}
// Step 2: Find the body and buffer allocations of the pipeline. The body can be direct child of
// the for-loop. If the for-loop has BlockRealize as its child, the pipeline body will be the
// child of the block.
// Step 2: Find the body and buffer allocations of the pipeline. The body
// can be the direct child of the for-loop. If the for-loop has a BlockRealize
// as its child, the pipeline body will be the child of the block.
Stmt pipeline_body{nullptr};
Array<Buffer> pipeline_allocs;
if (const auto* realize = for_node->body.as<BlockRealizeNode>()) {
const auto& block = realize->block;
for (const auto& buffer : block->alloc_buffers) {
if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
for (const auto &buffer : block->alloc_buffers) {
ICHECK(buffer->IsInstance<BufferNode>());
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
......@@ -801,31 +855,32 @@ class PipelineInjector : private StmtExprMutator {
pipeline_body = for_node->body;
}
const SeqStmtNode* pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
CHECK(pipeline_body_seq)
<< "ValueError: The body of the software pipeline should be SeqStmt, got "
<< pipeline_body->GetTypeKey();
const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
"should be SeqStmt, got "
<< pipeline_body->GetTypeKey();
// Step 3: Blockize the components of the pipeline. Each child of the pipelined loop will be
// converted into a block.
// Step 3: Blockize the components of the pipeline. Each child of the
// pipelined loop will be converted into a block.
PipelineInfo pipeline_info;
Array<Block> original_order; // pipeline body blocks in the original order
Array<Block> original_order; // pipeline body blocks in the original order
auto f_add_child = [&](const Stmt& child) {
auto f_add_child = [&](const Stmt &child) {
original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
};
for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
const auto* nested_block_realize = pipeline_body_seq->seq[i].as<BlockRealizeNode>();
const auto *nested_block_realize =
pipeline_body_seq->seq[i].as<BlockRealizeNode>();
if (nested_block_realize && is_one(nested_block_realize->predicate) &&
nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
const Block& nested_pipeline_block = nested_block_realize->block;
ICHECK(
nested_pipeline_block->match_buffers.empty()); // match_buffer should have been lowered
for (const auto& buffer : nested_pipeline_block->alloc_buffers) {
const Block &nested_pipeline_block = nested_block_realize->block;
ICHECK(nested_pipeline_block->match_buffers
.empty()); // match_buffer should have been lowered
for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
pipeline_allocs.push_back(buffer);
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
const auto* nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
const auto *nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
for (size_t j = 0; j < nested_seq->seq.size(); j++) {
f_add_child(nested_seq->seq[j]);
}
......@@ -834,21 +889,26 @@ class PipelineInjector : private StmtExprMutator {
}
}
auto pipeline_stages =
Downcast<Array<Integer>>(op->annotations.at(tir::attr::software_pipeline_stage));
auto pipeline_orders =
Downcast<Array<Integer>>(op->annotations.at(tir::attr::software_pipeline_order));
auto pipeline_stages = Downcast<Array<Integer>>(
op->annotations.at(tir::attr::software_pipeline_stage));
auto pipeline_orders = Downcast<Array<Integer>>(
op->annotations.at(tir::attr::software_pipeline_order));
CHECK_EQ(pipeline_stages.size(), original_order.size())
<< "PrimFunc " << global_symbol_ << " has original order "
<< original_order.Map([](const auto& block) { return block->name_hint; })
<< ", but pipeline annotation is " << pipeline_stages << " with different size";
<< original_order.Map(
[](const auto &block) { return block->name_hint; })
<< ", but pipeline annotation is " << pipeline_stages
<< " with different size";
CHECK_EQ(pipeline_orders.size(), original_order.size())
<< "PrimFunc " << global_symbol_ << " has original order "
<< original_order.Map([](const auto& block) { return block->name_hint; })
<< ", but pipeline annotation is " << pipeline_orders << " with different size";
<< original_order.Map(
[](const auto &block) { return block->name_hint; })
<< ", but pipeline annotation is " << pipeline_orders
<< " with different size";
std::unordered_set<int> pipeline_async_stages;
if (auto annot = op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
if (auto annot =
op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
for (auto s : Downcast<Array<Integer>>(annot)) {
pipeline_async_stages.insert(s->value);
}
......@@ -856,43 +916,44 @@ class PipelineInjector : private StmtExprMutator {
for (size_t i = 0; i < pipeline_stages.size(); i++) {
int stage = static_cast<int>(pipeline_stages[i]->value);
bool is_async = pipeline_async_stages.find(stage) != pipeline_async_stages.end();
PipelineAnnotation stage_order{stage,
/*order=*/static_cast<int>(pipeline_orders[i]->value),
is_async};
bool is_async =
pipeline_async_stages.find(stage) != pipeline_async_stages.end();
PipelineAnnotation stage_order{
stage,
/*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
pipeline_info.emplace(original_order[i], stage_order);
}
ValidatePipelineBody(pipeline_info, original_order);
// Step 4: Rewrite the pipeline body.
Stmt pipeline =
PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs, GetRef<For>(op), pipeline_info)
.BuildPipeline();
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
GetRef<For>(op), pipeline_info)
.BuildPipeline();
if (const auto* realize = op->body.as<BlockRealizeNode>()) {
const auto& block = realize->block;
for (const auto& buffer : block->alloc_buffers) {
if (const auto *realize = op->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
for (const auto &buffer : block->alloc_buffers) {
buffer_data_to_buffer_.erase(buffer->data);
}
}
return pipeline;
}
Stmt VisitStmt_(const BlockNode* op) final {
for (const auto& buffer : op->alloc_buffers) {
Stmt VisitStmt_(const BlockNode *op) final {
for (const auto &buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
for (const auto& buffer : op->alloc_buffers) {
for (const auto &buffer : op->alloc_buffers) {
buffer_data_to_buffer_.erase(buffer->data);
}
return std::move(block);
}
bool HasPipelineAnnotation(const ForNode* op) const {
bool HasPipelineAnnotation(const ForNode *op) const {
auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
bool has_stage = it1 != op->annotations.end();
......@@ -901,10 +962,12 @@ class PipelineInjector : private StmtExprMutator {
return true;
}
if (has_stage) {
LOG(FATAL) << "ValueError: Order of the software pipeline is not defined.";
LOG(FATAL)
<< "ValueError: Order of the software pipeline is not defined.";
}
if (has_order) {
LOG(FATAL) << "ValueError: Stage of the software pipeline is not defined.";
LOG(FATAL)
<< "ValueError: Stage of the software pipeline is not defined.";
}
return false;
}
......@@ -914,13 +977,13 @@ class PipelineInjector : private StmtExprMutator {
};
/*!
* \brief Transform annotated loops into pipelined one that parallelize producers and consumers.
* \return The IR transform pass.
* \brief Transform annotated loops into pipelined ones that parallelize
* producers and consumers.
* \return The IR transform pass.
*/
tir::transform::Pass InjectSoftwarePipeline() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* fptr = f.CopyOnWrite();
auto *fptr = f.CopyOnWrite();
fptr->body = PipelineInjector::Inject(f);
fptr->body = ConvertSSA(std::move(fptr->body));
return f;
......@@ -931,5 +994,5 @@ tir::transform::Pass InjectSoftwarePipeline() {
TVM_REGISTER_GLOBAL("tl.transform.InjectSoftwarePipeline")
.set_body_typed(InjectSoftwarePipeline);
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -30,11 +30,11 @@
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "common/loop_fusion_utils.h"
namespace tvm {
namespace tl {
......@@ -49,7 +49,7 @@ struct LayoutInferenceResult {
};
class BufferUseDefCollector : public StmtExprVisitor {
public:
public:
BufferUseDefCollector() = default;
LayoutInferenceResult Run() {
......@@ -59,22 +59,27 @@ class BufferUseDefCollector : public StmtExprVisitor {
// Maintain a BFS queue and infer a common layout across operators.
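// Sketch of the intended fixpoint: every operator is enqueued once, and
// whenever InferLayout fixes the layout of a new buffer, every other operator
// that uses that buffer is re-enqueued, so inference repeats until the queue
// drains at the current inference level.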
std::queue<int> q;
std::vector<bool> in_queue(num_infer, true);
for (int i = 0; i < num_infer; i++) q.push(i);
for (int i = 0; i < num_infer; i++)
q.push(i);
auto run_infer_step = [&](int cur_infer_id, InferLevel level, bool update_queue) {
auto& next = infer_list_[cur_infer_id];
auto run_infer_step = [&](int cur_infer_id, InferLevel level,
bool update_queue) {
auto &next = infer_list_[cur_infer_id];
auto iter_var = thread_var_vec_[cur_infer_id];
auto updates = next->InferLayout(
LayoutInferArgs{target_, static_cast<size_t>(*as_const_int(iter_var->dom->extent)),
layout_map},
LayoutInferArgs{
target_,
static_cast<size_t>(*as_const_int(iter_var->dom->extent)),
layout_map},
level);
for (const auto& [buffer, layout] : updates) {
for (const auto &[buffer, layout] : updates) {
if (layout_map.count(buffer)) {
ICHECK(StructuralEqual()(layout, layout_map[buffer]))
<< "Get different layout for " << buffer;
} else {
layout_map.Set(buffer, layout);
if (!update_queue) continue;
if (!update_queue)
continue;
for (int idx : use_list_[buffer]) {
if (!in_queue[idx] && idx != cur_infer_id) {
in_queue[idx] = true;
......@@ -108,16 +113,17 @@ class BufferUseDefCollector : public StmtExprVisitor {
}
// Check that all fragments have been inferred
for (const auto& [buffer, _] : use_list_) {
for (const auto &[buffer, _] : use_list_) {
if (buffer.scope() == "local.fragment" && layout_map.count(buffer) == 0)
LOG_ERROR << "The layout for fragment " << buffer << " can not be inferred correctly.";
LOG_ERROR << "The layout for fragment " << buffer
<< " can not be inferred correctly.";
}
// Collect the loop layouts for the For nodes
Map<For, Fragment> for_map;
Map<For, PrimExpr> predicate_map;
for (auto& base_infer : infer_list_) {
if (auto for_infer = dynamic_cast<ParallelOp*>(base_infer.get())) {
for (auto &base_infer : infer_list_) {
if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
ICHECK(for_infer->GetLoopLayout().defined())
<< "The Layout for Parallel for can not be inferred correctly : \n"
<< for_infer->GetRoot();
......@@ -130,25 +136,27 @@ class BufferUseDefCollector : public StmtExprVisitor {
return {layout_map, for_map, predicate_map};
}
void Collect(const PrimFunc& f) {
for (const auto& [_, buffer] : f->buffer_map) {
void Collect(const PrimFunc &f) {
for (const auto &[_, buffer] : f->buffer_map) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "Layout_Inference: Require the target attribute";
ICHECK(target.defined())
<< "Layout_Inference: Require the target attribute";
target_ = target.value();
this->operator()(f->body);
}
private:
void VisitExpr_(const CallNode* op) final {
private:
void VisitExpr_(const CallNode *op) final {
StmtExprVisitor::VisitExpr_(op);
// Do not analyze call nodes that target a global function.
if (op->op.as<GlobalVarNode>()) return;
if (op->op.as<GlobalVarNode>())
return;
auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
if (p != nullptr) {
for (const auto& arg : op->args) {
for (const auto &arg : op->args) {
if (auto buffer = getBufferFromAccessPtr(arg)) {
addToUseList(buffer.value());
}
......@@ -158,7 +166,7 @@ class BufferUseDefCollector : public StmtExprVisitor {
}
}
Optional<Buffer> getBufferFromAccessPtr(const PrimExpr& expr) {
Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
auto call = expr.as<CallNode>();
if (call && call->op.same_as(builtin::tvm_access_ptr())) {
auto var = call->args[1].as<Var>().value();
......@@ -167,7 +175,7 @@ class BufferUseDefCollector : public StmtExprVisitor {
return NullOpt;
}
void addToUseList(const Buffer& buffer) {
void addToUseList(const Buffer &buffer) {
int infer_idx = infer_list_.size();
if (use_list_.find(buffer) == use_list_.end()) {
use_list_[buffer] = {};
......@@ -175,10 +183,10 @@ class BufferUseDefCollector : public StmtExprVisitor {
use_list_[buffer].push_back(infer_idx);
}
void VisitStmt_(const ForNode* op) final {
void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kParallel) {
auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
for (const auto& [buffer, _] : infer->GetIndiceMap()) {
for (const auto &[buffer, _] : infer->GetIndiceMap()) {
addToUseList(buffer);
}
infer_list_.push_back(std::move(infer));
......@@ -188,13 +196,14 @@ class BufferUseDefCollector : public StmtExprVisitor {
}
}
void VisitStmt_(const BlockNode* op) final {
void VisitStmt_(const BlockNode *op) final {
for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
if (op->annotations.count(attr::kLayoutMap)) {
auto map = op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
for (const auto& [var, layout] : map) {
auto map =
op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
for (const auto &[var, layout] : map) {
auto buffer = buffer_data_to_buffer_[var];
ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
annotated_layout_map_.Set(buffer, layout);
......@@ -203,7 +212,7 @@ class BufferUseDefCollector : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AttrStmtNode* op) final {
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
......@@ -216,7 +225,8 @@ class BufferUseDefCollector : public StmtExprVisitor {
Map<Var, Buffer> buffer_data_to_buffer_;
std::vector<std::unique_ptr<Operator>> infer_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual> use_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
use_list_;
IterVar thread_var_;
std::vector<IterVar> thread_var_vec_;
Target target_;
......@@ -224,10 +234,10 @@ class BufferUseDefCollector : public StmtExprVisitor {
};
class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
public:
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
PrimFuncNode* fptr = f.CopyOnWrite();
PrimFuncNode *fptr = f.CopyOnWrite();
fptr->body = ParallelLoopFuser::Fuse(f->body);
BufferUseDefCollector collector;
collector.Collect(f);
......@@ -237,11 +247,12 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
return f;
}
private:
LayoutInferencer(const LayoutInferenceResult result, arith::Analyzer* analyzer)
: arith::IRMutatorWithAnalyzer(analyzer), result_(result) {};
private:
LayoutInferencer(const LayoutInferenceResult result,
arith::Analyzer *analyzer)
: arith::IRMutatorWithAnalyzer(analyzer), result_(result) {}
Stmt VisitStmt_(const BlockNode* op) final {
Stmt VisitStmt_(const BlockNode *op) final {
Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
for (auto buffer : block->alloc_buffers) {
......@@ -255,11 +266,12 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
return block;
}
Stmt VisitStmt_(const ForNode* op) final {
Stmt VisitStmt_(const ForNode *op) final {
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (result_.for_map.count(GetRef<For>(op))) {
auto loop_layout = result_.for_map[GetRef<For>(op)];
for_node = PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
for_node =
PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
for_node = VectorizeLoop(for_node);
if (result_.predicate_map.count(GetRef<For>(op))) {
return IfThenElse(result_.predicate_map[GetRef<For>(op)], for_node);
......@@ -270,7 +282,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
return for_node;
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
ICHECK_NE(iv->thread_tag.length(), 0U);
......@@ -281,7 +293,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
private:
private:
const LayoutInferenceResult result_;
IterVar thread_var_;
};
......@@ -297,5 +309,5 @@ tvm::transform::Pass LayoutInference() {
TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
.set_body_typed(LayoutInference);
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -30,8 +30,8 @@
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
......@@ -43,11 +43,11 @@ using arith::IRMutatorWithAnalyzer;
// Helper class to find leaf For nodes in a given IR
class LeafForFinder : public StmtVisitor {
public:
public:
std::vector<For> leaf_for_nodes;
private:
void VisitStmt_(const ForNode* op) final {
private:
void VisitStmt_(const ForNode *op) final {
has_child_for_ = false;
bool parent_has_child_for = parent_has_child_for_;
parent_has_child_for_ = false;
......@@ -62,7 +62,7 @@ class LeafForFinder : public StmtVisitor {
parent_has_child_for_ = true;
}
private:
private:
bool has_child_for_ = false;
bool parent_has_child_for_ = false;
};
......@@ -75,11 +75,11 @@ class LeafForFinder : public StmtVisitor {
// If the index might exceed the shape (upper bound too large),
// log a warning or handle accordingly.
struct GlobalMemChecker : public StmtExprVisitor {
arith::Analyzer* analyzer;
arith::Analyzer *analyzer;
explicit GlobalMemChecker(arith::Analyzer* analyzer) : analyzer(analyzer) {}
explicit GlobalMemChecker(arith::Analyzer *analyzer) : analyzer(analyzer) {}
void VisitExpr_(const BufferLoadNode* op) final {
void VisitExpr_(const BufferLoadNode *op) final {
// Check if the buffer is in global scope
if (IsGlobalBuffer(op->buffer)) {
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true);
......@@ -87,7 +87,7 @@ struct GlobalMemChecker : public StmtExprVisitor {
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) final {
void VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is in global scope
if (IsGlobalBuffer(op->buffer)) {
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false);
......@@ -96,21 +96,24 @@ struct GlobalMemChecker : public StmtExprVisitor {
}
// Helper function to determine if a buffer is global
bool IsGlobalBuffer(const Buffer& buffer) {
// The storage scope is often encoded in the buffer->data var name or associated attributes.
// In typical TVM IR, global buffers have scope "global".
// Here we assume a helper function GetPtrStorageScope is available.
// If not, you might need to parse buffer->data->name_hint or associated attributes.
bool IsGlobalBuffer(const Buffer &buffer) {
// The storage scope is often encoded in the buffer->data var name or
// associated attributes. In typical TVM IR, global buffers have scope
// "global". Here we assume a helper function GetPtrStorageScope is
// available. If not, you might need to parse buffer->data->name_hint or
// associated attributes.
String scope = buffer.scope();
return scope == "global";
}
// Check each index against the buffer shape dimensions
void CheckBufferIndices(const Buffer& buffer, const Array<PrimExpr>& indices, bool is_load) {
void CheckBufferIndices(const Buffer &buffer, const Array<PrimExpr> &indices,
bool is_load) {
// Ensure indices count matches buffer dimension
if (indices.size() != buffer->shape.size()) {
LOG(WARNING) << "Buffer access dimension mismatch: indices size (" << indices.size()
<< ") vs. shape size (" << buffer->shape.size() << ")";
LOG(WARNING) << "Buffer access dimension mismatch: indices size ("
<< indices.size() << ") vs. shape size ("
<< buffer->shape.size() << ")";
return;
}
......@@ -130,18 +133,19 @@ struct GlobalMemChecker : public StmtExprVisitor {
Array<PrimExpr> GetConditions() { return _conditions; }
private:
private:
Array<PrimExpr> _conditions;
};
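// Illustrative effect (sketch; the actual rewrite lives in SafeMemorysRewriter
// below): for a global store A[i] = v whose index cannot be proven in-bounds,
// the collected condition (roughly i < A.shape[0]) is used to guard the
// access, conceptually producing
//   if (i < A.shape[0]) A[i] = v;
// extern calls handled below, such as the atomicAddx2 example, are guarded
// analogously.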
class SafeMemorysRewriter : public StmtExprMutator {
arith::Analyzer* analyzer_;
arith::Analyzer *analyzer_;
public:
explicit SafeMemorysRewriter(arith::Analyzer* analyzer) : analyzer_(analyzer) {}
public:
explicit SafeMemorysRewriter(arith::Analyzer *analyzer)
: analyzer_(analyzer) {}
private:
Stmt VisitStmt_(const BufferStoreNode* op) final {
private:
Stmt VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is in global scope
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
GlobalMemChecker checker(analyzer_);
......@@ -173,12 +177,13 @@ class SafeMemorysRewriter : public StmtExprMutator {
// Handle Call Nodes
// For example
// T.call_extern("handle", "atomicAddx2", T.address_of(C), T.address_of(C_shared))
Stmt VisitStmt_(const EvaluateNode* op) final {
// T.call_extern("handle", "atomicAddx2", T.address_of(C),
// T.address_of(C_shared))
Stmt VisitStmt_(const EvaluateNode *op) final {
auto evaluate = Downcast<Evaluate>(StmtExprMutator::VisitStmt_(op));
auto call = Downcast<Call>(evaluate->value);
if (call.defined() && call->op == builtin::call_extern()) {
GlobalMemChecker checker(analyzer_);
checker(call);
Array<PrimExpr> conditions = checker.GetConditions();
......@@ -197,13 +202,12 @@ class SafeMemorysRewriter : public StmtExprMutator {
return evaluate;
}
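// --- Illustrative sketch (not part of this diff) -----------------------------
// The elided tail of this visitor presumably guards the extern call with the
// conditions collected above. The usual TVM idiom for that (the helper name is
// hypothetical) is to fold the conditions with && and wrap the statement in an
// IfThenElse:
static Stmt GuardWithConditions(Stmt body, const Array<PrimExpr> &conditions) {
  if (conditions.empty()) {
    return body; // nothing left to prove, keep the statement as-is
  }
  PrimExpr predicate = conditions[0];
  for (size_t i = 1; i < conditions.size(); ++i) {
    predicate = predicate && conditions[i]; // conjunction of all bound checks
  }
  return IfThenElse(predicate, body);
}
// ------------------------------------------------------------------------------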
bool isSharedBuffer(const Buffer& buffer) {
bool isSharedBuffer(const Buffer &buffer) {
String scope = buffer.scope();
return scope == "shared" || scope == "shared.dyn";
}
bool IsGlobalBuffer(const Buffer& buffer) {
bool IsGlobalBuffer(const Buffer &buffer) {
String scope = buffer.scope();
return scope == "global";
}
......@@ -211,32 +215,34 @@ class SafeMemorysRewriter : public StmtExprMutator {
// Class to legalize safe memory access by transforming them appropriately
class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
public:
public:
// Static method to substitute and transform the given PrimFunc
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
// Create an instance of the legalizer with the analyzer
SafeMemoryLegalizer substituter(&analyzer);
// Get a mutable copy of the function node
PrimFuncNode* fptr = f.CopyOnWrite();
PrimFuncNode *fptr = f.CopyOnWrite();
// Apply the legalizer to the function body
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
private:
// Constructor initializing the base class with the analyzer
SafeMemoryLegalizer(arith::Analyzer* analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {}
SafeMemoryLegalizer(arith::Analyzer *analyzer)
: arith::IRMutatorWithAnalyzer(analyzer) {}
// Override the VisitStmt_ method to handle ForNode (loop statements)
Stmt VisitStmt_(const ForNode* op) final {
Stmt VisitStmt_(const ForNode *op) final {
// Visit and potentially modify the loop node
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto has_inner_loop = HasInnerLoop(for_node->body);
if (!has_inner_loop) {
SafeMemorysRewriter rewriter(analyzer_);
for_node.CopyOnWrite()->body = rewriter(for_node->body);
// // Detect Buffer Load Node in the loop body, collect the indices and buffer size
// // Detect Buffer Load Node in the loop body, collect the indices and
// buffer size
// // Run the checker on the loop body
// GlobalMemChecker checker(analyzer_);
......@@ -257,7 +263,7 @@ class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
static bool HasInnerLoop(const Stmt& stmt) {
static bool HasInnerLoop(const Stmt &stmt) {
LeafForFinder finder;
finder(stmt);
return finder.leaf_for_nodes.size() > 0;
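// --- Illustrative sketch (not part of this diff) -----------------------------
// LeafForFinder itself is defined elsewhere in the repository; a minimal
// visitor with the same observable behaviour (collecting loops whose body
// contains no further loop) could be written as below. Only the member name
// leaf_for_nodes is taken from the usage above; everything else, including the
// element type and the need for <vector>, is an assumption.
struct LeafForFinderSketch : public StmtVisitor {
  std::vector<const ForNode *> leaf_for_nodes;

  void VisitStmt_(const ForNode *op) final {
    has_inner_ = false;             // does this loop contain a nested For?
    StmtVisitor::VisitStmt_(op);    // recurse into the loop body
    if (!has_inner_) {
      leaf_for_nodes.push_back(op); // no nested For seen -> leaf loop
    }
    has_inner_ = true; // signal the enclosing loop that it has an inner one
  }

private:
  bool has_inner_ = false;
};
// ------------------------------------------------------------------------------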
......@@ -279,5 +285,5 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
TVM_REGISTER_GLOBAL("tl.transform.LegalizeSafeMemoryAccess")
.set_body_typed(LegalizeSafeMemoryAccess);
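// --- Illustrative sketch (not part of this diff) -----------------------------
// The definition of LegalizeSafeMemoryAccess() itself is elided by the hunk
// above. Passes of this kind are conventionally built by wrapping the
// legalizer's Substitute entry point in CreatePrimFuncPass; a hypothetical
// equivalent (function name, opt level and pass name are assumptions) is:
tvm::transform::Pass LegalizeSafeMemoryAccessSketch() {
  auto pass_func = [](PrimFunc f, IRModule, tvm::transform::PassContext) {
    return SafeMemoryLegalizer::Substitute(std::move(f));
  };
  return tvm::tir::transform::CreatePrimFuncPass(
      pass_func, 0, "tl.LegalizeSafeMemoryAccess", {});
}
// ------------------------------------------------------------------------------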
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -30,8 +30,8 @@
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
......@@ -43,25 +43,26 @@ using arith::IRMutatorWithAnalyzer;
// Class to legalize vectorized loops by transforming them appropriately
class LoopVectorizedLegalizer : IRMutatorWithAnalyzer {
public:
public:
// Static method to substitute and transform the given PrimFunc
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
// Create an instance of the legalizer with the analyzer
LoopVectorizedLegalizer substituter(&analyzer);
// Get a mutable copy of the function node
PrimFuncNode* fptr = f.CopyOnWrite();
PrimFuncNode *fptr = f.CopyOnWrite();
// Apply the legalizer to the function body
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
private:
// Constructor initializing the base class with the analyzer
LoopVectorizedLegalizer(arith::Analyzer* analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {}
LoopVectorizedLegalizer(arith::Analyzer *analyzer)
: arith::IRMutatorWithAnalyzer(analyzer) {}
// Override the VisitStmt_ method to handle ForNode (loop statements)
Stmt VisitStmt_(const ForNode* op) final {
Stmt VisitStmt_(const ForNode *op) final {
// Visit and potentially modify the loop node
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
// If the loop is not vectorized, proceed with the default behavior
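// A hedged guess at the check elided by the hunk below: vectorized loops are
// identified through their ForKind, e.g.
//   if (for_node->kind != ForKind::kVectorized) return std::move(for_node);
// and everything else is returned unchanged, as the comment above says.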
......@@ -90,5 +91,5 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
TVM_REGISTER_GLOBAL("tl.transform.LegalizeVectorizedLoop")
.set_body_typed(LegalizeVectorizedLoop);
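// --- Illustrative usage sketch (not part of this diff) ------------------------
// How the two legalizer passes could be chained from C++. The helper name, the
// pass ordering, and the assumption that both pass constructors are declared in
// a shared header are all illustrative, not taken from this repository.
IRModule ApplyLegalizersSketch(IRModule mod) {
  tvm::transform::Pass seq = tvm::transform::Sequential(
      {LegalizeVectorizedLoop(), LegalizeSafeMemoryAccess()}, "tl.LegalizeSketch");
  return seq(mod); // runs both PrimFunc passes over every function in the module
}
// ------------------------------------------------------------------------------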
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm