Commit c71e140d authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent a5011336
...@@ -69,10 +69,19 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio ...@@ -69,10 +69,19 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMultipleD_Xdl_CShuffle< using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMultipleD_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::Tuple<>, ck::Tuple<>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK, ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType, InDataType,
WeiDataType, WeiDataType,
AccDataType, AccDataType,
......
...@@ -6,9 +6,24 @@ ...@@ -6,9 +6,24 @@
#include <cstdlib> #include <cstdlib>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <iterator>
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os;
}
template <typename T, std::size_t N>
std::ostream& operator<<(std::ostream& os, const std::array<T, N>& v)
{
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os;
}
template <typename... Ts> template <typename... Ts>
std::ostream& operator<<(std::ostream& os, const ck::TensorDescriptor<Ts...>& desc) std::ostream& operator<<(std::ostream& os, const ck::TensorDescriptor<Ts...>& desc)
{ {
......
...@@ -183,7 +183,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -183,7 +183,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay, template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NWC>, typename std::enable_if<NDimSpatial == 1 &&
is_same_v<ALay, tensor_layout::convolution::NWC>,
bool>::type = false> bool>::type = false>
static auto static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
...@@ -294,7 +295,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -294,7 +295,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
template <typename ALay, template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NHWC>, typename std::enable_if<NDimSpatial == 2 &&
is_same_v<ALay, tensor_layout::convolution::NHWC>,
bool>::type = false> bool>::type = false>
static auto static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
...@@ -419,7 +421,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -419,7 +421,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
template <typename ALay, template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NDHWC>, typename std::enable_if<NDimSpatial == 3 &&
is_same_v<ALay, tensor_layout::convolution::NDHWC>,
bool>::type = false> bool>::type = false>
static auto static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
...@@ -925,16 +928,6 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -925,16 +928,6 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(arg.a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * sizeof(ADataType) > TwoGB ||
arg.b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * sizeof(BDataType) > TwoGB ||
arg.e_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) > TwoGB)
{
return false;
}
// check ConvolutionForwardSpecialization // check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
...@@ -1020,7 +1013,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1020,7 +1013,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
// Gridwise GEMM size // check Gridwise GEMM
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
......
...@@ -26,11 +26,11 @@ namespace ck { ...@@ -26,11 +26,11 @@ namespace ck {
// E = cde_op(C, D0, D1, ...) // E = cde_op(C, D0, D1, ...)
// Assume: // Assume:
// D0, D1, ... and E have the same layout // D0, D1, ... and E have the same layout
template <typename FloatAB, template <typename ABDataType, // FIXME: don't assume A/B have same datatype
typename FloatGemmAcc, typename AccDataType,
typename FloatCShuffle, typename CShuffleDataType,
typename DsDataType, typename DsDataType,
typename FloatE, typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
...@@ -160,8 +160,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -160,8 +160,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB), sizeof(ABDataType),
c_block_size * sizeof(FloatCShuffle)); c_block_size * sizeof(CShuffleDataType));
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -256,6 +256,16 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -256,6 +256,16 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
} }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true; return true;
} }
...@@ -283,10 +293,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -283,10 +293,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap> typename Block2ETileMap>
__device__ static void __device__ static void
Run(const FloatAB* __restrict__ p_a_grid, Run(const ABDataType* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid, DsGridPointer p_ds_grid,
FloatE* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
...@@ -355,8 +365,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -355,8 +365,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Sequence<AK0PerBlock, MPerBlock, AK1>, Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, ABDataType,
FloatAB, ABDataType,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -386,8 +396,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -386,8 +396,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Sequence<BK0PerBlock, NPerBlock, BK1>, Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, ABDataType,
FloatAB, ABDataType,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -415,13 +425,14 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -415,13 +425,14 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
// sanity check // sanity check
constexpr index_t KPack = math::max( constexpr index_t KPack =
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatAB, ABDataType,
FloatGemmAcc, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
MPerXdl, MPerXdl,
...@@ -438,10 +449,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -438,10 +449,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
...@@ -502,7 +513,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -502,7 +513,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared), static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
...@@ -554,8 +565,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -554,8 +565,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// shuffle: threadwise copy C from VGPR to LDS // shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds = auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
FloatCShuffle, CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
...@@ -612,8 +623,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -612,8 +623,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// blockwise copy C/D/E between LDS and global // blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock, ThisThreadBlock,
decltype(container_concat(make_tuple(FloatCShuffle{}), DsDataType{})), decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<FloatE>, Tuple<EDataType>,
decltype(c_ds_desc_refs), decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation, CDEElementwiseOperation,
......
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/host_utility/io.hpp"
#include "ck/library/utility/io.hpp"
namespace ck { namespace ck {
namespace utils { namespace utils {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <vector>
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os;
}
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host_utility/io.hpp"
#include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/io.hpp"
namespace ck { namespace ck {
namespace utils { namespace utils {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/device_utility/hip_check_error.hpp" #include "ck/host_utility/hip_check_error.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment