Commit ddb0c230 authored by turneram's avatar turneram
Browse files

Formatting

parent 127393f4
...@@ -48,7 +48,6 @@ static constexpr auto I3 = ck::Number<3>{}; ...@@ -48,7 +48,6 @@ static constexpr auto I3 = ck::Number<3>{};
static constexpr auto I4 = ck::Number<4>{}; static constexpr auto I4 = ck::Number<4>{};
static constexpr auto I5 = ck::Number<5>{}; static constexpr auto I5 = ck::Number<5>{};
static constexpr ck::index_t K1 = 1; static constexpr ck::index_t K1 = 1;
static constexpr auto K1Number = ck::Number<K1>{}; static constexpr auto K1Number = ck::Number<K1>{};
...@@ -96,8 +95,6 @@ using CThreadTransferSrcDstAccessOrder = S<0, 1, 2, 3, 4, 5>; ...@@ -96,8 +95,6 @@ using CThreadTransferSrcDstAccessOrder = S<0, 1, 2, 3, 4, 5>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = 5; static constexpr ck::index_t CThreadTransferSrcDstVectorDim = 5;
static constexpr ck::index_t CThreadTransferDstScalarPerVector = 4; static constexpr ck::index_t CThreadTransferDstScalarPerVector = 4;
static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ck::index_t StrideA) static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ck::index_t StrideA)
{ {
assert(K % K1 == 0); assert(K % K1 == 0);
...@@ -194,9 +191,9 @@ static constexpr auto MakeCGridDescriptor_M_N(ck::index_t M, ck::index_t N, ck:: ...@@ -194,9 +191,9 @@ static constexpr auto MakeCGridDescriptor_M_N(ck::index_t M, ck::index_t N, ck::
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor( return transform_tensor_descriptor(c_grid_desc_m_n,
c_grid_desc_m_n, ck::make_tuple(ck::make_right_pad_transform(M, PadM),
ck::make_tuple(ck::make_right_pad_transform(M, PadM), ck::make_right_pad_transform(N, PadN)), ck::make_right_pad_transform(N, PadN)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
} }
...@@ -230,12 +227,15 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t) ...@@ -230,12 +227,15 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
constexpr auto cstrides = get_shape_c<V>{}.strides; constexpr auto cstrides = get_shape_c<V>{}.strides;
constexpr auto cs = cstrides[0]; constexpr auto cs = cstrides[0];
auto idx = make_index(); auto idx = make_index();
if (idx.global == 0) if(idx.global == 0)
printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs)); printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs));
auto a_grid_desc_k0_m_k1 = MakeAGridDescriptor_K0_M_K1(static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as)); auto a_grid_desc_k0_m_k1 = MakeAGridDescriptor_K0_M_K1(
auto b_grid_desc_k0_n_k1 = MakeBGridDescriptor_K0_N_K1(static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs)); static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as));
auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs)); auto b_grid_desc_k0_n_k1 = MakeBGridDescriptor_K0_N_K1(
static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));
using GridwiseGemm = using GridwiseGemm =
ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize, ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType, ADataType,
...@@ -271,7 +271,6 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t) ...@@ -271,7 +271,6 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector>;
auto a_grid_desc_k0_m0_m1_k1 = auto a_grid_desc_k0_m0_m1_k1 =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1); GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
auto b_grid_desc_k0_n0_n1_k1 = auto b_grid_desc_k0_n0_n1_k1 =
...@@ -280,10 +279,18 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t) ...@@ -280,10 +279,18 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n); GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n); auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
constexpr bool HasMainKBlockLoop = true; constexpr bool HasMainKBlockLoop = true;
constexpr bool HasDoubleTailKBlockLoop = true; constexpr bool HasDoubleTailKBlockLoop = true;
GridwiseGemm::Run(a_t.data(), b_t.data(), c_t.data(), p_t.data(), a_grid_desc_k0_m0_m1_k1, b_grid_desc_k0_n0_n1_k1, c_grid_desc_m0_m10_m11_n0_n10_n11, block_2_ctile_map, ck::integral_constant<bool, HasMainKBlockLoop>{}, ck::integral_constant<bool, HasDoubleTailKBlockLoop>{}); GridwiseGemm::Run(a_t.data(),
b_t.data(),
c_t.data(),
p_t.data(),
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment