Commit 6e28a8ac authored by aska-0096's avatar aska-0096
Browse files

format

parent 708fd81f
...@@ -167,7 +167,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -167,7 +167,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for K0, K1, ... // lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
const auto a_grid_desc_m_k = [&](){ const auto a_grid_desc_m_k = [&]() {
if constexpr(ASpec == TensorSpecialization::Packed) if constexpr(ASpec == TensorSpecialization::Packed)
{ {
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
...@@ -256,7 +256,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -256,7 +256,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for N0, N1, ... // lengths for N0, N1, ...
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
const auto b_grid_desc_n_k = [&](){ const auto b_grid_desc_n_k = [&]() {
if constexpr(BSpec == TensorSpecialization::Packed) if constexpr(BSpec == TensorSpecialization::Packed)
{ {
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
...@@ -522,8 +522,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -522,8 +522,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EGridDesc_G_M_N e_grid_desc_g_m_n_; EGridDesc_G_M_N e_grid_desc_g_m_n_;
}; };
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({},{})); using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({}, {}));
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({},{})); using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {}));
// GridwiseOp // GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
...@@ -648,7 +648,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -648,7 +648,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
e_grid_desc_m_n_ = e_grid_desc_m_n_ =
DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01); block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
ds_grid_desc_mblock_mperblock_nblock_nperblock = ds_grid_desc_mblock_mperblock_nblock_nperblock =
...@@ -686,7 +685,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -686,7 +685,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
DsGridDesc_G_M_N ds_grid_desc_g_m_n_; DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
EGridDesc_G_M_N e_grid_desc_g_m_n_; EGridDesc_G_M_N e_grid_desc_g_m_n_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock; ds_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -163,16 +162,14 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -163,16 +162,14 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
const auto b_grid_desc_nraw_kraw = const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
make_tuple(I1, StrideB));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
const auto b_grid_desc_nraw_kraw = const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
make_tuple(StrideB, I1));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
} }
......
...@@ -153,16 +153,14 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -153,16 +153,14 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
const auto b_grid_desc_nraw_kraw = const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
make_tuple(I1, StrideB));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
const auto b_grid_desc_nraw_kraw = const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
make_tuple(StrideB, I1));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
} }
...@@ -304,8 +302,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -304,8 +302,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA); a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_k0_n_k1_ = b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB);
DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ = block_2_ctile_map_ =
......
...@@ -184,8 +184,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -184,8 +184,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
static auto static auto MakeAGridDescriptor(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
MakeAGridDescriptor(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
...@@ -244,8 +243,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -244,8 +243,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
} }
template <typename BLay> template <typename BLay>
static auto static auto MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides) const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
...@@ -320,7 +318,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -320,7 +318,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>; using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {})); using AGridDesc =
decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {})); using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {}));
// GridwiseOp // GridwiseOp
...@@ -423,8 +422,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -423,8 +422,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads)}, input_right_pads)},
b_grid_desc_{DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths, b_grid_desc_{
b_g_k_c_xs_strides)}, DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)}, block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
......
...@@ -513,8 +513,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -513,8 +513,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
} }
template <typename BBlockDesc_> template <typename BBlockDesc_>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
MakeBWaveDescriptor(const BBlockDesc_&)
{ {
constexpr auto b_wave_desc = [&]() { constexpr auto b_wave_desc = [&]() {
if constexpr(BEnableLds) if constexpr(BEnableLds)
...@@ -595,8 +594,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -595,8 +594,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap> template <typename Block2CTileMap>
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const DsGridDesc_M_N& ds_grid_desc_m_n, const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n,
...@@ -628,13 +626,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -628,13 +626,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
if constexpr(BEnableLds) if constexpr(BEnableLds)
{ {
return make_tuple(b_grid_desc.GetLength(I1), return make_tuple(b_grid_desc.GetLength(I1),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
b_grid_desc.GetLength(I2));
} }
else else
{ {
return make_tuple( return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I4), b_grid_desc.GetLength(I4),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I5)); b_grid_desc.GetLength(I5));
...@@ -747,8 +743,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -747,8 +743,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
max_lds_align) max_lds_align)
: 0; : 0;
static constexpr auto b_block_space_size_aligned = static constexpr auto b_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple( BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
MakeBBlockDescriptor().GetElementSpaceSize(),
max_lds_align) max_lds_align)
: 0; : 0;
......
...@@ -327,8 +327,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -327,8 +327,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
} }
template <typename BBlockDesc_> template <typename BBlockDesc_>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
MakeBWaveDescriptor(const BBlockDesc_&)
{ {
constexpr auto b_wave_desc = [&]() { constexpr auto b_wave_desc = [&]() {
if constexpr(BEnableLds) if constexpr(BEnableLds)
...@@ -394,8 +393,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -394,8 +393,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap> template <typename Block2CTileMap>
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const CGridDesc_M_N& c_grid_desc_m_n, const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
...@@ -426,13 +424,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -426,13 +424,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
if constexpr(BEnableLds) if constexpr(BEnableLds)
{ {
return make_tuple(b_grid_desc.GetLength(I1), return make_tuple(b_grid_desc.GetLength(I1),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
b_grid_desc.GetLength(I2));
} }
else else
{ {
return make_tuple( return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I4), b_grid_desc.GetLength(I4),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I5)); b_grid_desc.GetLength(I5));
......
...@@ -1399,12 +1399,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ...@@ -1399,12 +1399,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
if constexpr(IntraRowSwizzlePerm) if constexpr(IntraRowSwizzlePerm)
{ {
temp = __builtin_amdgcn_permlane16( temp = __builtin_amdgcn_permlane16(
temp, temp, type_convert<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
type_convert<int>(v_this_row),
0xb3a29180,
0xf7e6d5c4,
1,
0);
v_this_row = type_convert<SrcData>(temp); v_this_row = type_convert<SrcData>(temp);
} }
......
#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' # find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment