Commit 6e28a8ac authored by aska-0096's avatar aska-0096
Browse files

format

parent 708fd81f
...@@ -54,12 +54,12 @@ using DeviceConvFwdInstance = ...@@ -54,12 +54,12 @@ using DeviceConvFwdInstance =
256, // BlockSize 256, // BlockSize
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
8, // K1 8, // K1
16, // MPerWMMA 16, // MPerWMMA
16, // NPerWMMA 16, // NPerWMMA
4, // MRepeat 4, // MRepeat
2, // NRepeat 2, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder
......
...@@ -140,7 +140,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -140,7 +140,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{ {
assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK);
...@@ -167,7 +167,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -167,7 +167,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for K0, K1, ... // lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
const auto a_grid_desc_m_k = [&](){ const auto a_grid_desc_m_k = [&]() {
if constexpr(ASpec == TensorSpecialization::Packed) if constexpr(ASpec == TensorSpecialization::Packed)
{ {
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
...@@ -229,7 +229,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -229,7 +229,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{ {
assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);
...@@ -256,7 +256,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -256,7 +256,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for N0, N1, ... // lengths for N0, N1, ...
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
const auto b_grid_desc_n_k = [&](){ const auto b_grid_desc_n_k = [&]() {
if constexpr(BSpec == TensorSpecialization::Packed) if constexpr(BSpec == TensorSpecialization::Packed)
{ {
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
...@@ -522,8 +522,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -522,8 +522,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EGridDesc_G_M_N e_grid_desc_g_m_n_; EGridDesc_G_M_N e_grid_desc_g_m_n_;
}; };
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({},{})); using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({}, {}));
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({},{})); using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {}));
// GridwiseOp // GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
...@@ -648,7 +648,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -648,7 +648,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
e_grid_desc_m_n_ = e_grid_desc_m_n_ =
DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01); block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
ds_grid_desc_mblock_mperblock_nblock_nperblock = ds_grid_desc_mblock_mperblock_nblock_nperblock =
...@@ -686,7 +685,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -686,7 +685,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
DsGridDesc_G_M_N ds_grid_desc_g_m_n_; DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
EGridDesc_G_M_N e_grid_desc_g_m_n_; EGridDesc_G_M_N e_grid_desc_g_m_n_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock; ds_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -163,16 +162,14 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -163,16 +162,14 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
const auto b_grid_desc_nraw_kraw = const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
make_tuple(I1, StrideB));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
const auto b_grid_desc_nraw_kraw = const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
make_tuple(StrideB, I1));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
} }
...@@ -260,10 +257,10 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -260,10 +257,10 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
} }
// Gridwise descriptor, mapping to whole given provblem. // Gridwise descriptor, mapping to whole given provblem.
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1)); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
// GridwiseOp // GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
......
...@@ -153,16 +153,14 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -153,16 +153,14 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
const auto b_grid_desc_nraw_kraw = const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
make_tuple(I1, StrideB));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
const auto b_grid_desc_nraw_kraw = const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
make_tuple(StrideB, I1));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
} }
...@@ -219,9 +217,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -219,9 +217,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
} }
// Gridwise descriptor, mapping to whole given provblem. // Gridwise descriptor, mapping to whole given provblem.
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma<
...@@ -303,10 +301,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -303,10 +301,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA); a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_k0_n_k1_ = b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB);
DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB); c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
......
...@@ -184,17 +184,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -184,17 +184,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
static auto static auto MakeAGridDescriptor(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
MakeAGridDescriptor(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_right_pads)
const std::array<index_t, NDimSpatial>& input_right_pads)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths, conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
...@@ -210,7 +209,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -210,7 +209,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
const auto M = in_gemmm_gemmk_desc.GetLength(I0); const auto M = in_gemmm_gemmk_desc.GetLength(I0);
const auto K = in_gemmm_gemmk_desc.GetLength(I1); const auto K = in_gemmm_gemmk_desc.GetLength(I1);
assert(K % K1 == 0); assert(K % K1 == 0);
...@@ -244,9 +243,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -244,9 +243,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
} }
template <typename BLay> template <typename BLay>
static auto static auto MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths, conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
...@@ -254,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -254,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const auto wei_gemmn_gemmk_desc = const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
const auto N = wei_gemmn_gemmk_desc.GetLength(I0); const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
const auto K = wei_gemmn_gemmk_desc.GetLength(I1); const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
assert(K % K1 == 0); assert(K % K1 == 0);
...@@ -320,7 +318,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -320,7 +318,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>; using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {})); using AGridDesc =
decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {})); using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {}));
// GridwiseOp // GridwiseOp
...@@ -414,17 +413,17 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -414,17 +413,17 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)}, e_g_n_k_wos_strides)},
a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths, a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, e_g_n_k_wos_lengths,
e_g_n_k_wos_strides, e_g_n_k_wos_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads)}, input_right_pads)},
b_grid_desc_{DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths, b_grid_desc_{
b_g_k_c_xs_strides)}, DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)}, block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
......
...@@ -513,8 +513,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -513,8 +513,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
} }
template <typename BBlockDesc_> template <typename BBlockDesc_>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
MakeBWaveDescriptor(const BBlockDesc_&)
{ {
constexpr auto b_wave_desc = [&]() { constexpr auto b_wave_desc = [&]() {
if constexpr(BEnableLds) if constexpr(BEnableLds)
...@@ -595,12 +594,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -595,12 +594,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap> template <typename Block2CTileMap>
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
CheckValidity(const AGridDesc& a_grid_desc, const BGridDesc& b_grid_desc,
const BGridDesc& b_grid_desc, const DsGridDesc_M_N& ds_grid_desc_m_n,
const DsGridDesc_M_N& ds_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n, const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map)
{ {
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value, static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time"); "wrong! K1 need to be known at compile-time");
...@@ -628,16 +626,14 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -628,16 +626,14 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
if constexpr(BEnableLds) if constexpr(BEnableLds)
{ {
return make_tuple(b_grid_desc.GetLength(I1), return make_tuple(b_grid_desc.GetLength(I1),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
b_grid_desc.GetLength(I2));
} }
else else
{ {
return make_tuple( return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * b_grid_desc.GetLength(I4),
b_grid_desc.GetLength(I4), b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * b_grid_desc.GetLength(I5));
b_grid_desc.GetLength(I5));
} }
}; };
...@@ -747,9 +743,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -747,9 +743,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
max_lds_align) max_lds_align)
: 0; : 0;
static constexpr auto b_block_space_size_aligned = static constexpr auto b_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple( BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
MakeBBlockDescriptor().GetElementSpaceSize(), max_lds_align)
max_lds_align)
: 0; : 0;
static constexpr auto a_block_space_offset = 0; static constexpr auto a_block_space_offset = 0;
......
...@@ -327,8 +327,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -327,8 +327,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
} }
template <typename BBlockDesc_> template <typename BBlockDesc_>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
MakeBWaveDescriptor(const BBlockDesc_&)
{ {
constexpr auto b_wave_desc = [&]() { constexpr auto b_wave_desc = [&]() {
if constexpr(BEnableLds) if constexpr(BEnableLds)
...@@ -394,11 +393,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -394,11 +393,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap> template <typename Block2CTileMap>
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
CheckValidity(const AGridDesc& a_grid_desc, const BGridDesc& b_grid_desc,
const BGridDesc& b_grid_desc, const CGridDesc_M_N& c_grid_desc_m_n,
const CGridDesc_M_N& c_grid_desc_m_n, const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map)
{ {
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value, static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time"); "wrong! K1 need to be known at compile-time");
...@@ -426,16 +424,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -426,16 +424,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
if constexpr(BEnableLds) if constexpr(BEnableLds)
{ {
return make_tuple(b_grid_desc.GetLength(I1), return make_tuple(b_grid_desc.GetLength(I1),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
b_grid_desc.GetLength(I2));
} }
else else
{ {
return make_tuple( return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * b_grid_desc.GetLength(I4),
b_grid_desc.GetLength(I4), b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * b_grid_desc.GetLength(I5));
b_grid_desc.GetLength(I5));
} }
}; };
......
...@@ -1398,13 +1398,8 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ...@@ -1398,13 +1398,8 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
if constexpr(IntraRowSwizzlePerm) if constexpr(IntraRowSwizzlePerm)
{ {
temp = __builtin_amdgcn_permlane16( temp = __builtin_amdgcn_permlane16(
temp, temp, type_convert<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
type_convert<int>(v_this_row),
0xb3a29180,
0xf7e6d5c4,
1,
0);
v_this_row = type_convert<SrcData>(temp); v_this_row = type_convert<SrcData>(temp);
} }
......
#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' # find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment