"...resnet50_tensorflow.git" did not exist on "45fd6cee4f846870616ec77a307924e944a7336d"
Commit 34520806 authored by Paul's avatar Paul
Browse files

Format

parent 2ca29096
...@@ -40,11 +40,14 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c) ...@@ -40,11 +40,14 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
constexpr auto a_grid_desc_ak0_m_ak1 = gemm.MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>()); constexpr auto a_grid_desc_ak0_m_ak1 = gemm.MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>());
constexpr auto b_grid_desc_bk0_n_bk1 = gemm.MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>()); constexpr auto b_grid_desc_bk0_n_bk1 = gemm.MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>());
constexpr auto c_grid_desc_m_n = gemm.MakeCGridDescriptor_M_N(to_ck_tensor<C>()); constexpr auto c_grid_desc_m_n = gemm.MakeCGridDescriptor_M_N(to_ck_tensor<C>());
constexpr auto block_2_ctile_map = gemm.MakeDefaultBlock2CTileMap(c_grid_desc_m_n); constexpr auto block_2_ctile_map = gemm.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1), decltype(b_grid_desc_bk0_n_bk1), decltype(c_grid_desc_m_n)>; using GridwiseGemm = typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1),
// static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map)); decltype(b_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_m_n)>;
// static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1,
// c_grid_desc_m_n, block_2_ctile_map));
constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
......
...@@ -149,23 +149,21 @@ template <typename ALayout, ...@@ -149,23 +149,21 @@ template <typename ALayout,
ck::index_t CShuffleNXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock, ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler() ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
> struct CKDeviceGemm
struct CKDeviceGemm
{ {
static constexpr auto I0 = ck::Number<0>{}; static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{}; static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{}; static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{}; static constexpr auto I3 = ck::Number<3>{};
template<class Descriptor> template <class Descriptor>
static constexpr auto static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const Descriptor& a_grid_desc_mraw_kraw)
MakeAGridDescriptor_AK0_M_AK1(const Descriptor& a_grid_desc_mraw_kraw)
{ {
const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0); const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0);
const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1); const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1);
const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
const auto KPad = K - KRaw; const auto KPad = K - KRaw;
...@@ -253,14 +251,13 @@ struct CKDeviceGemm ...@@ -253,14 +251,13 @@ struct CKDeviceGemm
} }
} }
template<class Descriptor> template <class Descriptor>
static constexpr auto static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const Descriptor& b_grid_desc_nraw_kraw)
MakeBGridDescriptor_BK0_N_BK1(const Descriptor& b_grid_desc_nraw_kraw)
{ {
const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0); const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0);
const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1); const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1);
const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw; const auto NPad = N - NRaw;
const auto KPad = K - KRaw; const auto KPad = K - KRaw;
...@@ -348,14 +345,13 @@ struct CKDeviceGemm ...@@ -348,14 +345,13 @@ struct CKDeviceGemm
} }
} }
template<class Descriptor> template <class Descriptor>
static constexpr auto static constexpr auto MakeCGridDescriptor_M_N(const Descriptor& c_grid_desc_mraw_nraw)
MakeCGridDescriptor_M_N(const Descriptor& c_grid_desc_mraw_nraw)
{ {
const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0); const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1); const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1);
const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
const auto NPad = N - NRaw; const auto NPad = N - NRaw;
...@@ -407,8 +403,8 @@ struct CKDeviceGemm ...@@ -407,8 +403,8 @@ struct CKDeviceGemm
// using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1()); // using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1());
// using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N()); // using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
template<class CGridDesc_M_N> template <class CGridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
...@@ -416,7 +412,7 @@ struct CKDeviceGemm ...@@ -416,7 +412,7 @@ struct CKDeviceGemm
c_grid_desc_m_n); c_grid_desc_m_n);
} }
template<class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N> template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
using GridwiseGemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
...@@ -461,7 +457,7 @@ struct CKDeviceGemm ...@@ -461,7 +457,7 @@ struct CKDeviceGemm
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
AElementwiseOperation a_element_op{}; AElementwiseOperation a_element_op{};
BElementwiseOperation b_element_op{}; BElementwiseOperation b_element_op{};
CElementwiseOperation c_element_op{}; CElementwiseOperation c_element_op{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment