Commit ddb0c230 authored by turneram
Browse files

Formatting

parent 127393f4
......@@ -48,7 +48,6 @@ static constexpr auto I3 = ck::Number<3>{};
static constexpr auto I4 = ck::Number<4>{};
static constexpr auto I5 = ck::Number<5>{};
static constexpr ck::index_t K1 = 1;
static constexpr auto K1Number = ck::Number<K1>{};
......@@ -96,8 +95,6 @@ using CThreadTransferSrcDstAccessOrder = S<0, 1, 2, 3, 4, 5>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = 5;
static constexpr ck::index_t CThreadTransferDstScalarPerVector = 4;
static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ck::index_t StrideA)
{
assert(K % K1 == 0);
......@@ -194,9 +191,9 @@ static constexpr auto MakeCGridDescriptor_M_N(ck::index_t M, ck::index_t N, ck::
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
ck::make_tuple(ck::make_right_pad_transform(M, PadM), ck::make_right_pad_transform(N, PadN)),
return transform_tensor_descriptor(c_grid_desc_m_n,
ck::make_tuple(ck::make_right_pad_transform(M, PadM),
ck::make_right_pad_transform(N, PadN)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
}
......@@ -230,12 +227,15 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
constexpr auto cstrides = get_shape_c<V>{}.strides;
constexpr auto cs = cstrides[0];
auto idx = make_index();
if (idx.global == 0)
if(idx.global == 0)
printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs));
auto a_grid_desc_k0_m_k1 = MakeAGridDescriptor_K0_M_K1(static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as));
auto b_grid_desc_k0_n_k1 = MakeBGridDescriptor_K0_N_K1(static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));
auto a_grid_desc_k0_m_k1 = MakeAGridDescriptor_K0_M_K1(
static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as));
auto b_grid_desc_k0_n_k1 = MakeBGridDescriptor_K0_N_K1(
static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));
using GridwiseGemm =
ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType,
......@@ -271,7 +271,6 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
auto a_grid_desc_k0_m0_m1_k1 =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
auto b_grid_desc_k0_n0_n1_k1 =
......@@ -280,10 +279,18 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
constexpr bool HasMainKBlockLoop = true;
constexpr bool HasDoubleTailKBlockLoop = true;
GridwiseGemm::Run(a_t.data(), b_t.data(), c_t.data(), p_t.data(), a_grid_desc_k0_m0_m1_k1, b_grid_desc_k0_n0_n1_k1, c_grid_desc_m0_m10_m11_n0_n10_n11, block_2_ctile_map, ck::integral_constant<bool, HasMainKBlockLoop>{}, ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
GridwiseGemm::Run(a_t.data(),
b_t.data(),
c_t.data(),
p_t.data(),
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
} // namespace migraphx
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment