"sims/vscode:/vscode.git/clone" did not exist on "ff3367af6bdf119e23b260023bdeb8cd1a7a7c1d"
Commit cfc80c01 authored by ltqin

Merge branch 'develop' into ck_conv_bww_fp16

parents 69ea9ad9 6d4450ef
# Compiled Object files
*.slo
*.lo
*.o
*.obj
# Precompiled Headers
*.gch
*.pch
*.ipch
# Compiled Dynamic libraries
*.so
*.dylib
*.dll
# Fortran module files
*.mod
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Executables
*.exe
*.out
*.app
# vim tags
tags
.tags
.*.swp
# Editors
.vscode
# in-source build directories
build*
# emacs temporary/backup files
.\#*
\#*\#
*~
# GDB temporary files
.gdb_history
\ No newline at end of file
......@@ -59,14 +59,19 @@
#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
#endif
// AMD buffer addressing
#ifndef CK_USE_AMD_BUFFER_ADDRESSING
#define CK_USE_AMD_BUFFER_ADDRESSING 1
// AMD buffer_load
#ifndef CK_USE_AMD_BUFFER_LOAD
#define CK_USE_AMD_BUFFER_LOAD 1
#endif
// only gfx908 supports native floating-point atomic add
#ifndef CK_USE_AMD_BUFFER_ATOMIC_FADD
#define CK_USE_AMD_BUFFER_ATOMIC_FADD 0
// AMD buffer_store
#ifndef CK_USE_AMD_BUFFER_STORE
#define CK_USE_AMD_BUFFER_STORE 1
#endif
// AMD buffer_atomic_add
#ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD
#define CK_USE_AMD_BUFFER_ATOMIC_ADD 1
#endif
// AMD XDLOPS
......@@ -97,9 +102,6 @@
#define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1
#endif
// pass tensor descriptor by value or void*
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0
// merge transformation use magic number division
......@@ -167,6 +169,7 @@ enum ActivTypeEnum_t
// index type
using index_t = int32_t;
using long_index_t = int64_t;
} // namespace ck
#endif
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: in
// B: wei
// C: out
// GemmM = N * Do * Ho * Wo
// GemmN = K
// GemmK = Z * Y * X * C
template <typename... In,
typename... Wei,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad(
const TensorDescriptor<In...>& in_grid_desc_n_di_hi_wi_c,
const TensorDescriptor<Wei...>& wei_k_z_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_do_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_grid_desc_n_di_hi_wi_c.GetLength(I0);
const auto K = out_n_do_ho_wo_k_grid_desc.GetLength(I4);
const auto C = in_grid_desc_n_di_hi_wi_c.GetLength(I4);
const auto Di = in_grid_desc_n_di_hi_wi_c.GetLength(I1);
const auto Hi = in_grid_desc_n_di_hi_wi_c.GetLength(I2);
const auto Wi = in_grid_desc_n_di_hi_wi_c.GetLength(I3);
const auto Do = out_n_do_ho_wo_k_grid_desc.GetLength(I1);
const auto Ho = out_n_do_ho_wo_k_grid_desc.GetLength(I2);
const auto Wo = out_n_do_ho_wo_k_grid_desc.GetLength(I3);
const auto Z = wei_k_z_y_x_c_grid_desc.GetLength(I1);
const auto Y = wei_k_z_y_x_c_grid_desc.GetLength(I2);
const auto X = wei_k_z_y_x_c_grid_desc.GetLength(I3);
const auto ConvStrideD = conv_strides[I0];
const auto ConvStrideH = conv_strides[I1];
const auto ConvStrideW = conv_strides[I2];
const auto ConvDilationD = conv_dilations[I0];
const auto ConvDilationH = conv_dilations[I1];
const auto ConvDilationW = conv_dilations[I2];
const auto InLeftPadD = in_left_pads[I0];
const auto InLeftPadH = in_left_pads[I1];
const auto InLeftPadW = in_left_pads[I2];
const auto InRightPadD = in_right_pads[I0];
const auto InRightPadH = in_right_pads[I1];
const auto InRightPadW = in_right_pads[I2];
const auto GemmM = N * Do * Ho * Wo;
const auto GemmN = K;
const auto GemmK = Z * Y * X * C;
const auto GemmK0 = GemmK / GemmK1;
// A: input tensor
const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor(
in_grid_desc_n_di_hi_wi_c,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor(
in_grid_desc_n_dip_hip_wip_c,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{}));
const auto in_grid_desc_gemmk_gemmm =
transform_tensor_descriptor(in_grid_desc_n_z_do_y_ho_x_wo_c,
make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_grid_desc_gemmk0_gemmm_gemmk1 =
transform_tensor_descriptor(in_grid_desc_gemmk_gemmm,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: weight tensor
const auto wei_grid_desc_gemmk_gemmn = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Z * Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_grid_desc_gemmk0_gemmn_gemmk1 =
transform_tensor_descriptor(wei_grid_desc_gemmk_gemmn,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: output tensor
const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Do * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor(
// out_n_do_ho_wo_k_grid_desc,
// make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(in_grid_desc_gemmk0_gemmm_gemmk1,
wei_grid_desc_gemmk0_gemmn_gemmk1,
out_grid_desc_gemmm_gemmn);
}
} // namespace ck
#endif
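For orientation, a small worked example of the GEMM problem sizes produced by the transform above; the convolution sizes below are made up for illustration and are not taken from this change.
// made-up convolution sizes, only to illustrate the arithmetic in the function above
constexpr int N = 4, Do = 16, Ho = 16, Wo = 16;     // batch and output spatial sizes
constexpr int K = 256, Z = 3, Y = 3, X = 3, C = 64; // filter sizes
constexpr int GemmK1 = 8;
constexpr int GemmM  = N * Do * Ho * Wo;            // 16384
constexpr int GemmN  = K;                           // 256
constexpr int GemmK  = Z * Y * X * C;               // 1728
constexpr int GemmK0 = GemmK / GemmK1;              // 216
static_assert(GemmK % GemmK1 == 0, "GemmK1 must divide GemmK");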
......@@ -1862,5 +1862,92 @@ struct Slice
}
};
/*
* \brief lower_idx = upper_idx % modulus.
* TODO: Need an improved implementation since the modulo operation is expensive.
*/
template <typename Modulus, typename UpLength>
struct Modulo
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
Modulus modulus_;
UpLengths up_lengths_;
__host__ __device__ constexpr Modulo() = default;
__host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length)
: modulus_{modulus}, up_lengths_{make_tuple(up_length)}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& up_idx,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = Number<0>{};
const auto idx_low_old = idx_low;
idx_low(I0) = (up_idx(I0) + idx_diff_up(I0)) % modulus_;
idx_diff_low(I0) = idx_low - idx_low_old;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value;
}
__host__ __device__ void Print() const
{
printf("{");
printf("Modulus, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("}");
}
};
} // namespace ck
#endif
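For reference, a standalone plain-C++ sketch (outside the ck type system; the modulus and length below are made up) of the index mapping that Modulo::CalculateLowerIndex implements:
#include <cstdio>
int main()
{
    const int modulus   = 4;  // hypothetical modulus
    const int up_length = 10; // hypothetical upper-dimension length
    for(int idx_up = 0; idx_up < up_length; ++idx_up)
    {
        const int idx_low = idx_up % modulus; // lower_idx = upper_idx % modulus
        std::printf("up %d -> low %d\n", idx_up, idx_low);
    }
    // the lower index repeats as 0,1,2,3,0,1,2,3,0,1
    return 0;
}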
......@@ -98,6 +98,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
return Freeze<LowerIndex>{low_idx};
}
template <typename UpperIndex>
__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx)
{
return Insert<UpperIndex>{up_idx};
}
template <typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length,
const SliceBegin& slice_begin,
......@@ -113,5 +119,11 @@ __host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& ve
return Vectorize<VectorSize, UpLength>{vector_size, up_length};
}
template <typename Modulus, typename UpLength>
__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
const UpLength& up_length)
{
return Modulo<Modulus, UpLength>{modulus, up_length};
}
} // namespace ck
#endif
......@@ -307,6 +307,10 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
{
// sanity check
{
static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() &&
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
"wrong! inconsitent number of transform");
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldVisibleIdss{});
......
......@@ -17,7 +17,7 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t K1>
index_t KPack>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static constexpr auto I0 = Number<0>{};
......@@ -29,10 +29,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
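// Example with made-up compile-time sizes (not taken from this change):
//   AK0MK1BlockDesc = [A_K0 = 4, MPerBlock = 128, A_K1 = 8]
//   BK0NK1BlockDesc = [B_K0 = 4, NPerBlock = 128, B_K1 = 8]
//   => KPerBlock = B_K0 * B_K1 = 32
//   With MPerXDL = NPerXDL = 32 and MRepeat = NRepeat = 2:
//   => MWaves = 128 / (2 * 32) = 2 and NWaves = 128 / (2 * 32) = 2,
//      so BlockSize must be MWaves * NWaves * WaveSize = 2 * 2 * 64 = 256
//      (assuming WaveSize == 64; the static_assert below enforces the product)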
......@@ -66,7 +71,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0);
return make_tuple(0, waveId_m, xdlops_a_idx[I1], Number<KPack>{} * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
......@@ -77,7 +82,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0);
return make_tuple(0, waveId_n, xdlops_b_idx[I1], Number<KPack>{} * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
......@@ -115,12 +120,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
"wrong! K0 dimension not consistent");
static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2),
"wrong! K1 dimension not consistent");
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
......@@ -219,32 +218,32 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<K0>{}),
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
__host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<K0>{}),
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
......@@ -258,31 +257,31 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, I0, I0, I0),
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, I0, I0, I0),
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0),
b_thread_buf);
static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) {
vector_type<FloatAB, K1> a_thread_vec;
vector_type<FloatAB, K1> b_thread_vec;
static_for<0, KPerBlock, KPack * xdlops_gemm.K0PerXdlops>{}([&](auto k) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, K1, 1>{}([&](auto i) {
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});
using mfma_input_type =
......@@ -301,13 +300,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
}
private:
// A[K0, M0, M1, M2, K1]
// A[M0, M1, M2, KPerBlock]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerBlock>{}));
// B[K0, N0, N1, N2, K1]
// B[N0, N1, N2, KPerBlock]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerBlock>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
......@@ -315,23 +314,23 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<K0, 1, 1, 1, K1>,
Sequence<0, 1, 2, 3, 4>,
4,
K1,
K1>;
Sequence<1, 1, 1, KPerBlock>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<K0, 1, 1, 1, K1>,
Sequence<0, 1, 2, 3, 4>,
4,
K1,
K1>;
Sequence<1, 1, 1, KPerBlock>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
......
......@@ -33,7 +33,8 @@ template <index_t BlockSize,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
bool ThreadTransferDstResetCoordinateAfterRun,
index_t NumThreadScratch = 1>
struct BlockwiseTensorSliceTransfer_v4r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
......@@ -86,45 +87,39 @@ struct BlockwiseTensorSliceTransfer_v4r1
}
}
template <typename SrcBuffer, typename SrcStepHacks>
__device__ void
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
}
}
template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf);
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
}
}
template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf);
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
}
template <typename SrcBuffer, typename DstBuffer>
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf);
RunWrite(dst_desc, dst_buf);
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
......@@ -136,21 +131,6 @@ struct BlockwiseTensorSliceTransfer_v4r1
}
}
// SrcMoveSliceWindowStepHack to control index calculation move slice window
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_step_hack);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
......@@ -182,7 +162,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadTransferDstResetCoordinateAfterRun,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
......
#ifndef CK_ELEMENT_WISE_OPERATION_HPP
#define CK_ELEMENT_WISE_OPERATION_HPP
#include "data_type.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
......
......@@ -11,7 +11,6 @@
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseBatchedGemm,
typename FloatAB,
typename FloatC,
......@@ -53,64 +52,6 @@ __global__ void
c_element_op,
block_2_ctile_map);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseBatchedGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_G_K0_M_K1,
typename BGridDesc_G_K0_N_K1,
typename CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_xdlops_v2r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_grid_desc_g_k0_m_k1,
const void CONSTANT* p_b_grid_desc_g_k0_n_k1,
const void CONSTANT* p_c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
const void CONSTANT* p_a_element_op,
const void CONSTANT* p_b_element_op,
const void CONSTANT* p_c_element_op,
const void CONSTANT* p_block_2_ctile_map)
{
const auto a_grid_desc_g_k0_m_k1 = *reinterpret_cast<const AGridDesc_G_K0_M_K1*>(
cast_pointer_to_generic_address_space(p_a_grid_desc_g_k0_m_k1));
const auto b_grid_desc_g_k0_n_k1 = *reinterpret_cast<const BGridDesc_G_K0_N_K1*>(
cast_pointer_to_generic_address_space(p_b_grid_desc_g_k0_n_k1));
const auto c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2 =
*reinterpret_cast<const CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2*>(
cast_pointer_to_generic_address_space(p_c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2));
const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
cast_pointer_to_generic_address_space(p_block_2_ctile_map));
const auto a_element_op = *reinterpret_cast<const AElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_a_element_op));
const auto b_element_op = *reinterpret_cast<const BElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_b_element_op));
const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_c_element_op));
__shared__ char p_shared[GridwiseBatchedGemm::GetSharedMemoryNumberOfByte()];
GridwiseBatchedGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_g_k0_m_k1,
b_grid_desc_g_k0_n_k1,
c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
#endif
template <index_t BlockSize,
typename FloatAB,
......@@ -391,7 +332,7 @@ struct GridwiseBatchedGemm_gk0mk1_gk0nk1_gmn_xdlops_v2r3
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeBlock2CTileMap(const CGridDesc_G_M_N& c_grid_desc_g_m_n, index_t M01, index_t N01)
MakeDefaultBlock2CTileMap(const CGridDesc_G_M_N& c_grid_desc_g_m_n, index_t M01, index_t N01)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
......@@ -414,24 +355,24 @@ struct GridwiseBatchedGemm_gk0mk1_gk0nk1_gmn_xdlops_v2r3
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto c_blockid_to_g_m00_m01_n00_n01_block_cluster_adaptor =
const auto cblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(G, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto c_blockid_to_g_m0_n0_block_cluster_adaptor =
const auto cblockid_to_g_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
c_blockid_to_g_m00_m01_n00_n01_block_cluster_adaptor);
cblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor);
return c_blockid_to_g_m0_n0_block_cluster_adaptor;
return cblockid_to_g_m0_n0_block_cluster_adaptor;
}
using CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_G_M_N{}));
using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_G_M_N{}, 1, 1));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_G_M_N{}, 1, 1));
template <bool HasMainKBlockLoop>
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......
......@@ -12,7 +12,6 @@
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
......@@ -33,7 +32,7 @@ __global__ void
const AKM0M1GridDesc a_k_m0_m1_grid_desc,
const BKN0N1GridDesc b_k_n0_n1_grid_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -47,66 +46,10 @@ __global__ void
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
cblockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by CONSTANT void pointer
// CONSTANT is needed to inform the compiler that the void pointers in the kernel signature point to
// the non-modifiable parameter address space, so the compiler can enable the corresponding optimization
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AKM0M1GridDesc,
typename BKN0N1GridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockIdToM0N0BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
// first cast 'const void CONSTANT*' to a generic 'const void*'
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k_m0_m1_grid_desc = *reinterpret_cast<const AKM0M1GridDesc*>(
cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc));
const auto b_k_n0_n1_grid_desc = *reinterpret_cast<const BKN0N1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc));
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif
template <index_t BlockSize,
typename FloatAB,
......@@ -298,12 +241,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
const auto cblockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_m0_n0_block_cluster_adaptor;
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
......@@ -321,7 +264,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
......@@ -336,7 +279,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
......
......@@ -12,7 +12,6 @@
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
......@@ -33,7 +32,7 @@ __global__ void
const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc,
const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -47,66 +46,10 @@ __global__ void
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
cblockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by CONSTANT void pointer
// CONSTANT is needed to inform the compiler that the void pointers in the kernel signature point to
// the non-modifiable parameter address space, so the compiler can enable the corresponding optimization
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AK0M0M1K1GridDesc,
typename BK0N0N1K1GridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockIdToM0N0BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
// first cast 'const void CONSTANT*' to a generic 'const void*'
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast<const AK0M0M1K1GridDesc*>(
cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc));
const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast<const BK0N0N1K1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc));
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif
template <index_t BlockSize,
typename FloatAB,
......@@ -305,12 +248,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
const auto cblockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_m0_n0_block_cluster_adaptor;
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{}));
......@@ -328,7 +271,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc,
const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
......@@ -341,7 +284,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
......
#ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
#define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
#include "common_header.hpp"
namespace ck {
template <typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer,
index_t NumPrefetch,
bool HasMainLoop>
struct GridwiseGemmPipeline_v1;
// 1-stage prefetch
template <typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer,
bool HasMainLoop>
struct GridwiseGemmPipeline_v1<AGridDesc,
ABlockDesc,
ABlockTransfer,
AGridBuffer,
ABlockBuffer,
ABlockTransferStep,
BGridDesc,
BBlockDesc,
BBlockTransfer,
BGridBuffer,
BBlockBuffer,
BBlockTransferStep,
BlockwiseGemm,
CThreadBuffer,
1,
HasMainLoop>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static __device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
#if 0
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#else
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#endif
}
};
// 2-stage prefetch
template <typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer,
bool HasMainLoop>
struct GridwiseGemmPipeline_v1<AGridDesc,
ABlockDesc,
ABlockTransfer,
AGridBuffer,
ABlockBuffer,
ABlockTransferStep,
BGridDesc,
BBlockDesc,
BBlockTransfer,
BGridBuffer,
BBlockBuffer,
BBlockTransferStep,
BlockwiseGemm,
CThreadBuffer,
2,
HasMainLoop>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static __device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// preload data into LDS
{
// Read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Read 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
}
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Write i
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
// Read i+2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
// Sync
block_sync_lds();
// Gemm i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// Sync
block_sync_lds();
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Write i+1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
// Read i+3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
// Sync
block_sync_lds();
// Gemm i+1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// Sync
block_sync_lds();
i += 2;
} while(i < (num_loop - 2));
}
// tail
{
// Write num_loop - 2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
// Sync
block_sync_lds();
// Gemm num_loop - 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// Sync
block_sync_lds();
// Write num_loop - 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
// Sync
block_sync_lds();
// Gemm num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck
#endif
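To make the 2-stage schedule easier to follow, here is a standalone plain-C++ sketch (my reading of the specialization above, with a made-up tile count) of the order in which global reads, LDS writes, and GEMMs are issued: scratch slots 0/1 double-buffer the global reads while a single LDS buffer is rewritten before every GEMM, which is also why the number of K0 tiles must be even.
#include <cstdio>
int main()
{
    const int num_loop = 4; // hypothetical number of K0 tiles; must be even
    std::printf("read tile 0 -> scratch 0\n");
    std::printf("read tile 1 -> scratch 1\n");
    int i = 0;
    for(; i < num_loop - 2; i += 2) // main body
    {
        std::printf("write scratch 0 -> LDS, read tile %d -> scratch 0, gemm tile %d\n", i + 2, i);
        std::printf("write scratch 1 -> LDS, read tile %d -> scratch 1, gemm tile %d\n", i + 3, i + 1);
    }
    // tail
    std::printf("write scratch 0 -> LDS, gemm tile %d\n", i);
    std::printf("write scratch 1 -> LDS, gemm tile %d\n", i + 1);
    return 0;
}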
......@@ -8,10 +8,10 @@
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
......@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop>
bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -41,7 +41,7 @@ __global__ void
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
......@@ -53,63 +53,6 @@ __global__ void
c_element_op,
block_2_ctile_map);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_grid_desc_k0_m_k1,
const void CONSTANT* p_b_grid_desc_k0_n_k1,
const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const void CONSTANT* p_a_element_op,
const void CONSTANT* p_b_element_op,
const void CONSTANT* p_c_element_op,
const void CONSTANT* p_block_2_ctile_map)
{
const auto a_grid_desc_k0_m_k1 = *reinterpret_cast<const AGridDesc_K0_M_K1*>(
cast_pointer_to_generic_address_space(p_a_grid_desc_k0_m_k1));
const auto b_grid_desc_k0_n_k1 = *reinterpret_cast<const BGridDesc_K0_N_K1*>(
cast_pointer_to_generic_address_space(p_b_grid_desc_k0_n_k1));
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
*reinterpret_cast<const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2*>(
cast_pointer_to_generic_address_space(p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2));
const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
cast_pointer_to_generic_address_space(p_block_2_ctile_map));
const auto a_element_op = *reinterpret_cast<const AElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_a_element_op));
const auto b_element_op = *reinterpret_cast<const BElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_b_element_op));
const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_c_element_op));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
#endif
template <index_t BlockSize,
typename FloatAB,
......@@ -148,7 +91,8 @@ template <index_t BlockSize,
bool BBlockLdsExtraN,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
index_t CThreadTransferDstScalarPerVector,
index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
static constexpr auto I0 = Number<0>{};
......@@ -252,6 +196,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check NumPrefetch
if constexpr(NumPrefetch == 1)
{
// 1-stage prefetch always supported
}
else if constexpr(NumPrefetch == 2)
{
// 2-stage prefetch currently only supports an even number of K0 loops
// TODO: add support for an odd number of K0 loops
if(!((K0 / K0PerBlock) % 2 == 0))
{
return false;
}
}
else
{
return false;
}
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
......@@ -277,9 +240,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return grid_size;
}
// TODO move this function into GEMM-pipeline class
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
return has_main_k0_block_loop;
}
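// Hypothetical numbers, only to illustrate the two checks above:
//   K0 = 64, K0PerBlock = 8  => K0 / K0PerBlock = 8 K0 loops (even, so NumPrefetch == 2 is allowed)
//     NumPrefetch = 1: has_main_k0_block_loop = (64 / (1 * 8)) > 1 = true
//     NumPrefetch = 2: has_main_k0_block_loop = (64 / (2 * 8)) > 1 = true
//   K0 = 16, K0PerBlock = 8, NumPrefetch = 2:
//     has_main_k0_block_loop = (16 / (2 * 8)) > 1 = false, so the pipeline runs only its preload and tail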
......@@ -336,7 +300,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
......@@ -357,24 +321,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor =
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor);
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return c_blockid_to_m0_n0_block_cluster_adaptor;
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
template <bool HasMainKBlockLoop>
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......@@ -439,7 +403,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
true,
NumPrefetch>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
......@@ -469,7 +434,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
true,
NumPrefetch>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
......@@ -513,51 +479,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
// preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
}
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
remove_cvref_t<decltype(a_blockwise_copy)>,
remove_cvref_t<decltype(a_grid_buf)>,
remove_cvref_t<decltype(a_block_buf)>,
remove_cvref_t<decltype(a_block_slice_copy_step)>,
remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
remove_cvref_t<decltype(b_blockwise_copy)>,
remove_cvref_t<decltype(b_grid_buf)>,
remove_cvref_t<decltype(b_block_buf)>,
remove_cvref_t<decltype(b_block_slice_copy_step)>,
remove_cvref_t<decltype(blockwise_gemm)>,
remove_cvref_t<decltype(c_thread_buf)>,
NumPrefetch,
HasMainK0BlockLoop>{};
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
a_block_desc_k0_m_k1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_n_k1,
b_block_desc_k0_n_k1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);
// output: register to global memory
{
......
......@@ -11,7 +11,6 @@
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
......@@ -55,67 +54,6 @@ __global__ void
c_element_op,
c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CM0N0M1N1M2M3M4N2GridDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_b_k0_m_k1_grid_desc,
const void CONSTANT* p_b_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const void CONSTANT* p_a_element_op,
const void CONSTANT* p_b_element_op,
const void CONSTANT* p_c_element_op,
const void CONSTANT* p_block_2_ctile_map)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
const auto a_b_k0_m_k1_grid_desc = *reinterpret_cast<const ABK0MK1GridDesc*>(
cast_pointer_to_generic_address_space(p_a_b_k0_m_k1_grid_desc));
const auto b_b_k0_n_k1_grid_desc = *reinterpret_cast<const BBK0NK1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_b_k0_n_k1_grid_desc));
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
*reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
cast_pointer_to_generic_address_space(p_block_2_ctile_map));
const auto a_element_op = *reinterpret_cast<const AElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_a_element_op));
const auto b_element_op = *reinterpret_cast<const BElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_b_element_op));
const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_c_element_op));
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
#endif
template <index_t BlockSize,
typename FloatAB,
......@@ -349,17 +287,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
const auto cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor =
const auto cblockid_to_kbatch_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor;
return cblockid_to_kbatch_m0_n0_block_cluster_adaptor;
}
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
......
......@@ -9,6 +9,7 @@
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r2.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck {
......@@ -23,7 +24,7 @@ template <typename GridwiseGemm,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop>
bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -46,7 +47,7 @@ __global__ void
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainK0BlockLoop>(
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -102,7 +103,8 @@ template <
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
{
static constexpr auto I0 = Number<0>{};
......@@ -235,6 +237,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check NumPrefetch
if constexpr(NumPrefetch == 1)
{
// 1-stage prefetch always supported
}
else if constexpr(NumPrefetch == 2)
{
// 2-stage prefetch currently only supports an even number of K0 loops
// TODO: add support for an odd number of K0 loops
if(!((K0 / K0PerBlock) % 2 == 0))
{
return false;
}
}
else
{
return false;
}
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
......@@ -260,9 +281,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
return grid_size;
}
// TODO move this function into GEMM-pipeline class
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
return has_main_k0_block_loop;
}
......@@ -296,7 +318,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
......@@ -317,17 +339,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor =
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor);
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return c_blockid_to_m0_n0_block_cluster_adaptor;
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype(
......@@ -339,9 +361,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
C0GridDesc_M_N{}))>;
using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
template <bool HasMainKBlockLoop>
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......@@ -416,7 +439,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
true,
NumPrefetch>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
......@@ -446,7 +470,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
true,
NumPrefetch>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
......@@ -490,51 +515,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
// preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
}
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
remove_cvref_t<decltype(a_blockwise_copy)>,
remove_cvref_t<decltype(a_grid_buf)>,
remove_cvref_t<decltype(a_block_buf)>,
remove_cvref_t<decltype(a_block_slice_copy_step)>,
remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
remove_cvref_t<decltype(b_blockwise_copy)>,
remove_cvref_t<decltype(b_grid_buf)>,
remove_cvref_t<decltype(b_block_buf)>,
remove_cvref_t<decltype(b_block_slice_copy_step)>,
remove_cvref_t<decltype(blockwise_gemm)>,
remove_cvref_t<decltype(c_thread_buf)>,
NumPrefetch,
HasMainK0BlockLoop>{};
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
a_block_desc_k0_m_k1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_n_k1,
b_block_desc_k0_n_k1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);
// shuffle C and write out
{
......