Commit ddb0c230 authored by turneram's avatar turneram
Browse files

Formatting

parent 127393f4
...@@ -48,12 +48,11 @@ static constexpr auto I3 = ck::Number<3>{}; ...@@ -48,12 +48,11 @@ static constexpr auto I3 = ck::Number<3>{};
static constexpr auto I4 = ck::Number<4>{}; static constexpr auto I4 = ck::Number<4>{};
static constexpr auto I5 = ck::Number<5>{}; static constexpr auto I5 = ck::Number<5>{};
static constexpr ck::index_t K1 = 1; static constexpr ck::index_t K1 = 1;
static constexpr auto K1Number = ck::Number<K1>{}; static constexpr auto K1Number = ck::Number<K1>{};
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using ALayout = Col; using ALayout = Col;
using BLayout = Row; using BLayout = Row;
using CLayout = Row; using CLayout = Row;
...@@ -69,34 +68,32 @@ template <ck::index_t... Is> ...@@ -69,34 +68,32 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
// Values hard-coded by CK // Values hard-coded by CK
static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t BlockSize = 256; static constexpr ck::index_t BlockSize = 256;
static constexpr ck::index_t K0PerBlock = 16; static constexpr ck::index_t K0PerBlock = 16;
static constexpr ck::index_t M1PerThread = 4; static constexpr ck::index_t M1PerThread = 4;
static constexpr ck::index_t N1PerThread = 4; static constexpr ck::index_t N1PerThread = 4;
static constexpr ck::index_t KPerThread = 1; static constexpr ck::index_t KPerThread = 1;
using M1N1ThreadClusterM1Xs = S<8, 2>; using M1N1ThreadClusterM1Xs = S<8, 2>;
using M1N1ThreadClusterN1Xs = S<8, 2>; using M1N1ThreadClusterN1Xs = S<8, 2>;
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = S<2, 1, 4, 1>; using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = S<2, 1, 4, 1>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = S<8, 1, 32, 1>; using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = S<8, 1, 32, 1>;
using ABlockTransferThreadClusterArrangeOrder = S<0, 3, 1, 2>; using ABlockTransferThreadClusterArrangeOrder = S<0, 3, 1, 2>;
using ABlockTransferSrcAccessOrder = S<0, 3, 1, 2>; using ABlockTransferSrcAccessOrder = S<0, 3, 1, 2>;
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>; using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;
using ABlockTransferSrcVectorTensorContiguousDimOrder = S<0, 3, 1, 2>; using ABlockTransferSrcVectorTensorContiguousDimOrder = S<0, 3, 1, 2>;
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>; using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = S<2, 1, 4, 1>; using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = S<2, 1, 4, 1>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = S<8, 1, 32, 1>; using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = S<8, 1, 32, 1>;
using BBlockTransferThreadClusterArrangeOrder = S<0, 3, 1, 2>; using BBlockTransferThreadClusterArrangeOrder = S<0, 3, 1, 2>;
using BBlockTransferSrcAccessOrder = S<0, 3, 1, 2>; using BBlockTransferSrcAccessOrder = S<0, 3, 1, 2>;
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>; using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;
using BBlockTransferSrcVectorTensorContiguousDimOrder = S<0, 3, 1, 2>; using BBlockTransferSrcVectorTensorContiguousDimOrder = S<0, 3, 1, 2>;
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>; using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;
using CThreadTransferSrcDstAccessOrder = S<0, 1, 2, 3, 4, 5>; using CThreadTransferSrcDstAccessOrder = S<0, 1, 2, 3, 4, 5>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = 5; static constexpr ck::index_t CThreadTransferSrcDstVectorDim = 5;
static constexpr ck::index_t CThreadTransferDstScalarPerVector = 4; static constexpr ck::index_t CThreadTransferDstScalarPerVector = 4;
static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ck::index_t StrideA) static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ck::index_t StrideA)
{ {
...@@ -122,7 +119,7 @@ static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ...@@ -122,7 +119,7 @@ static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K,
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_k, a_grid_desc_m_k,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)), ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
ck::make_right_pad_transform(M, PadM)), ck::make_right_pad_transform(M, PadM)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}), ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
} }
...@@ -131,7 +128,7 @@ static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ...@@ -131,7 +128,7 @@ static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K,
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_k, a_grid_desc_m_k,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)), ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
ck::make_pass_through_transform(M)), ck::make_pass_through_transform(M)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}), ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
} }
...@@ -161,7 +158,7 @@ static constexpr auto MakeBGridDescriptor_K0_N_K1(ck::index_t K, ck::index_t N, ...@@ -161,7 +158,7 @@ static constexpr auto MakeBGridDescriptor_K0_N_K1(ck::index_t K, ck::index_t N,
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_grid_desc_k_n, b_grid_desc_k_n,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)), ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
ck::make_right_pad_transform(N, PadN)), ck::make_right_pad_transform(N, PadN)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
} }
...@@ -170,7 +167,7 @@ static constexpr auto MakeBGridDescriptor_K0_N_K1(ck::index_t K, ck::index_t N, ...@@ -170,7 +167,7 @@ static constexpr auto MakeBGridDescriptor_K0_N_K1(ck::index_t K, ck::index_t N,
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_grid_desc_k_n, b_grid_desc_k_n,
ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)), ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
ck::make_pass_through_transform(N)), ck::make_pass_through_transform(N)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
} }
...@@ -194,11 +191,11 @@ static constexpr auto MakeCGridDescriptor_M_N(ck::index_t M, ck::index_t N, ck:: ...@@ -194,11 +191,11 @@ static constexpr auto MakeCGridDescriptor_M_N(ck::index_t M, ck::index_t N, ck::
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor( return transform_tensor_descriptor(c_grid_desc_m_n,
c_grid_desc_m_n, ck::make_tuple(ck::make_right_pad_transform(M, PadM),
ck::make_tuple(ck::make_right_pad_transform(M, PadM), ck::make_right_pad_transform(N, PadN)), ck::make_right_pad_transform(N, PadN)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
} }
else else
{ {
...@@ -229,48 +226,50 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t) ...@@ -229,48 +226,50 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
constexpr auto bs = bstrides[0]; constexpr auto bs = bstrides[0];
constexpr auto cstrides = get_shape_c<V>{}.strides; constexpr auto cstrides = get_shape_c<V>{}.strides;
constexpr auto cs = cstrides[0]; constexpr auto cs = cstrides[0];
auto idx = make_index(); auto idx = make_index();
if (idx.global == 0) if(idx.global == 0)
printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs)); printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs));
auto a_grid_desc_k0_m_k1 = MakeAGridDescriptor_K0_M_K1(static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as)); auto a_grid_desc_k0_m_k1 = MakeAGridDescriptor_K0_M_K1(
auto b_grid_desc_k0_n_k1 = MakeBGridDescriptor_K0_N_K1(static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs)); static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as));
auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs)); auto b_grid_desc_k0_n_k1 = MakeBGridDescriptor_K0_N_K1(
static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));
using GridwiseGemm = using GridwiseGemm =
ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize, ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType, ADataType,
AccDataType, AccDataType,
CDataType, CDataType,
ck::InMemoryDataOperationEnum::Set, ck::InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, K0PerBlock,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
M1N1ThreadClusterM1Xs, M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs, M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector>;
auto a_grid_desc_k0_m0_m1_k1 = auto a_grid_desc_k0_m0_m1_k1 =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1); GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
...@@ -280,10 +279,18 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t) ...@@ -280,10 +279,18 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n); GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n); auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
constexpr bool HasMainKBlockLoop = true;
constexpr bool HasMainKBlockLoop = true;
constexpr bool HasDoubleTailKBlockLoop = true; constexpr bool HasDoubleTailKBlockLoop = true;
GridwiseGemm::Run(a_t.data(), b_t.data(), c_t.data(), p_t.data(), a_grid_desc_k0_m0_m1_k1, b_grid_desc_k0_n0_n1_k1, c_grid_desc_m0_m10_m11_n0_n10_n11, block_2_ctile_map, ck::integral_constant<bool, HasMainKBlockLoop>{}, ck::integral_constant<bool, HasDoubleTailKBlockLoop>{}); GridwiseGemm::Run(a_t.data(),
b_t.data(),
c_t.data(),
p_t.data(),
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment