Commit c71e140d authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent a5011336
......@@ -69,10 +69,19 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
template <ck::index_t NDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMultipleD_Xdl_CShuffle<
NDimSpatial,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::Tuple<>,
ck::tensor_layout::convolution::NHWK,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
AccDataType,
......
......@@ -6,9 +6,24 @@
#include <cstdlib>
#include <iostream>
#include <vector>
#include <iterator>
#include "ck/tensor_description/tensor_descriptor.hpp"
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os;
}
template <typename T, std::size_t N>
std::ostream& operator<<(std::ostream& os, const std::array<T, N>& v)
{
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os;
}
template <typename... Ts>
std::ostream& operator<<(std::ostream& os, const ck::TensorDescriptor<Ts...>& desc)
{
......
......@@ -183,7 +183,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NWC>,
typename std::enable_if<NDimSpatial == 1 &&
is_same_v<ALay, tensor_layout::convolution::NWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
......@@ -294,7 +295,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NHWC>,
typename std::enable_if<NDimSpatial == 2 &&
is_same_v<ALay, tensor_layout::convolution::NHWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
......@@ -419,7 +421,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NDHWC>,
typename std::enable_if<NDimSpatial == 3 &&
is_same_v<ALay, tensor_layout::convolution::NDHWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
......@@ -925,16 +928,6 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false;
}
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(arg.a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * sizeof(ADataType) > TwoGB ||
arg.b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * sizeof(BDataType) > TwoGB ||
arg.e_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) > TwoGB)
{
return false;
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
......@@ -1020,7 +1013,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false;
}
// Gridwise GEMM size
// check Gridwise GEMM
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_,
......
......@@ -26,11 +26,11 @@ namespace ck {
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
template <typename ABDataType, // FIXME: don't assume A/B have same datatype
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename FloatE,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
......@@ -160,8 +160,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatCShuffle));
sizeof(ABDataType),
c_block_size * sizeof(CShuffleDataType));
}
__host__ __device__ static constexpr auto
......@@ -256,6 +256,16 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
......@@ -283,10 +293,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
FloatE* __restrict__ p_e_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
......@@ -355,8 +365,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
ABDataType,
ABDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
......@@ -386,8 +396,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
ABDataType,
ABDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
......@@ -415,13 +425,14 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatGemmAcc,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
......@@ -438,10 +449,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......@@ -502,7 +513,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
......@@ -554,8 +565,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
......@@ -612,8 +623,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(FloatCShuffle{}), DsDataType{})),
Tuple<FloatE>,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
......
......@@ -15,8 +15,7 @@
#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/io.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace utils {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <vector>
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host_utility/io.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/io.hpp"
namespace ck {
namespace utils {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/device_utility/hip_check_error.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/library/utility/device_memory.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment