Unverified Commit 6dfb92bb authored by Jianfeng Yan's avatar Jianfeng Yan Committed by GitHub
Browse files

Conv3d new (#94)



* conv3d compiles but has memory error

* conv3d works

* fix performance issue by using __builtin_amdgc_readfirstlane

* change MakeBlock2CTileMap to MakeDefaultBlock2CTileMap; change c_blockid_to* to cblockid_to*

* clang-format

* remove CK_EXPERIMENTAL_PASS_TENSOR_DECRIPTOR_BY_*; moved wrapper into DeviceConv3d

* format

* remove useless marc

* add comment
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent 19c5d6e6
......@@ -111,24 +111,39 @@ struct MagicDivision
}
// magic division for uint32_t
__host__ __device__ static constexpr uint32_t
__device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
}
__host__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = static_cast<uint64_t>(dividend) * multiplier >> 32;
return (tmp + dividend) >> shift;
}
// magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended
__host__ __device__ static constexpr int32_t
__device__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
__host__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = static_cast<uint64_t>(dividend_u32) * multiplier >> 32;
return (tmp + dividend_u32) >> shift;
}
};
} // namespace ck
......
......@@ -8,37 +8,5 @@ namespace ck {
template <index_t N>
using Number = integral_constant<index_t, N>;
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
return Number<X + Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
static_assert(Y <= X, "wrong!");
return Number<X - Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
return Number<X * Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X / Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X % Y>{};
}
} // namespace ck
#endif
......@@ -13,6 +13,8 @@ __device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size()
__device__ index_t get_block_1d_id() { return blockIdx.x; }
__device__ index_t get_grid_size() { return gridDim.x; }
} // namespace ck
#endif
......@@ -83,7 +83,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
void* p_a_k_m0_m1_grid_desc,
void* p_b_k_n0_n1_grid_desc,
void* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -194,7 +194,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
auto c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor =
auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0)
......@@ -203,8 +203,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
*static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc;
*static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>(
p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
*static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
};
};
......@@ -219,7 +219,7 @@ extern "C" __global__ void
const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -332,14 +332,13 @@ extern "C" __global__ void
GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp);
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp);
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k_m0_m1_grid_desc =
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
......@@ -348,9 +347,9 @@ extern "C" __global__ void
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
(const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -364,7 +363,7 @@ extern "C" __global__ void
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
cblockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
};
......@@ -79,7 +79,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -188,7 +188,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor =
auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0)
......@@ -199,8 +199,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
*static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
}
};
......@@ -215,7 +215,7 @@ extern "C" __global__ void
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
......@@ -325,12 +325,11 @@ extern "C" __global__ void
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
......@@ -338,9 +337,9 @@ extern "C" __global__ void
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
(const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -354,5 +353,5 @@ extern "C" __global__ void
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
cblockid_to_m0_n0_block_cluster_adaptor);
};
......@@ -79,7 +79,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -188,7 +188,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor =
auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0)
......@@ -199,8 +199,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
*static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
}
};
......@@ -215,7 +215,7 @@ extern "C" __global__ void
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
......@@ -324,12 +324,11 @@ extern "C" __global__ void
false>;
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
......@@ -337,9 +336,9 @@ extern "C" __global__ void
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
(const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -353,5 +352,5 @@ extern "C" __global__ void
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
cblockid_to_m0_n0_block_cluster_adaptor);
};
#ifndef CONVOLUTION_UTILITY_HPP
#define CONVOLUTION_UTILITY_HPP
#include <vector>
namespace ck {
namespace tensor_operation {
struct ConvolutionUtility
{
static std::vector<ck::index_t>
ComputeOutputSpatialLengths(std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_strides,
std::vector<ck::index_t> conv_dilations,
std::vector<ck::index_t> in_left_pads,
std::vector<ck::index_t> in_right_pads)
{
if(input_spatial_lengths.size() == 2)
{
assert(filter_spatial_lengths.size() == 2);
assert(conv_strides.size() == 2);
assert(conv_dilations.size() == 2);
assert(in_left_pads.size() == 2);
assert(in_right_pads.size() == 2);
const index_t YEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t XEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho =
(Hi + in_left_pads[0] + in_right_pads[0] - YEff) / conv_strides[0] + 1;
const index_t Wo =
(Wi + in_left_pads[1] + in_right_pads[1] - XEff) / conv_strides[1] + 1;
return {Ho, Wo};
}
else if(input_spatial_lengths.size() == 3)
{
assert(filter_spatial_lengths.size() == 3);
assert(conv_strides.size() == 3);
assert(conv_dilations.size() == 3);
assert(in_left_pads.size() == 3);
assert(in_right_pads.size() == 3);
const index_t ZEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t YEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t XEff = (filter_spatial_lengths[2] - 1) * conv_dilations[2] + 1;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do =
(Di + in_left_pads[0] + in_right_pads[0] - ZEff) / conv_strides[0] + 1;
const index_t Ho =
(Hi + in_left_pads[1] + in_right_pads[1] - YEff) / conv_strides[1] + 1;
const index_t Wo =
(Wi + in_left_pads[2] + in_right_pads[2] - XEff) / conv_strides[2] + 1;
return {Do, Ho, Wo};
}
else
{
return {};
}
}
};
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -248,7 +248,7 @@ struct DeviceBatchedGemmXdl
c_grid_desc_g_m_n_);
block_2_ctile_map_ =
GridwiseBatchedGemm::MakeBlock2CTileMap(c_grid_desc_g_m_n_, M01, N01);
GridwiseBatchedGemm::MakeDefaultBlock2CTileMap(c_grid_desc_g_m_n_, M01, N01);
}
}
......@@ -261,7 +261,7 @@ struct DeviceBatchedGemmXdl
CGridDesc_G_M_N c_grid_desc_g_m_n_;
typename GridwiseBatchedGemm::CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseBatchedGemm::Block2CTileMap block_2_ctile_map_;
typename GridwiseBatchedGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
......@@ -327,7 +327,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseBatchedGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseBatchedGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
......@@ -359,7 +359,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseBatchedGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseBatchedGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
......
......@@ -590,7 +590,8 @@ struct
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c1_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -614,7 +615,7 @@ struct
typename GridwiseGemm::
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
......@@ -694,7 +695,7 @@ struct
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(
......@@ -738,7 +739,7 @@ struct
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(
......
......@@ -561,7 +561,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c0_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -579,7 +580,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
typename GridwiseGemm::
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
......@@ -653,7 +654,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(
......@@ -692,7 +693,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(
......
......@@ -525,7 +525,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -538,7 +539,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
......@@ -628,7 +629,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(
......@@ -662,7 +663,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(
......
......@@ -415,7 +415,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -428,7 +429,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
......@@ -471,7 +472,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
arg.N01_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
......@@ -494,7 +495,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
......@@ -525,7 +526,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
......
#ifndef DEVICE_CONV3D_FWD_NAIVE_HPP
#define DEVICE_CONV3D_FWD_NAIVE_HPP
#include <iostream>
#include <memory>
#include <sstream>
#include "convolution_utility.hpp"
#include "device.hpp"
#include "device_conv_fwd.hpp"
#include "common_header.hpp"
#include "naive_conv_fwd.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
template <typename InDataType,
typename WeiDataType, // WeiDataType must be the same as InDataType
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
using DeviceOp = DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K;
using ADataType = InDataType;
using BDataType = WeiDataType;
using CDataType = OutDataType;
// TODO make A/B datatype different
using ABDataType = InDataType;
// Argument
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in,
const WeiDataType* p_wei,
OutDataType* p_out,
const index_t N,
const index_t K,
const index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: N_{N},
K_{K},
C_{C},
in_spatial_lengths_{input_spatial_lengths},
filter_spatial_lengths_{filter_spatial_lengths},
out_spatial_lengths_{output_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
in_right_pads_{input_right_pads},
p_in_{p_in},
p_wei_{p_wei},
p_out_{p_out},
in_element_op_{in_element_op},
wei_element_op_{wei_element_op},
out_element_op_{out_element_op}
{
}
// private:
index_t N_;
index_t K_;
index_t C_;
std::vector<index_t> in_spatial_lengths_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> out_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> conv_filter_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
const InDataType* p_in_;
const WeiDataType* p_wei_;
OutDataType* p_out_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
{
const auto naive_conv3d_fwd =
ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>;
float ave_time = launch_and_time_kernel(naive_conv3d_fwd,
nrepeat,
dim3(256),
dim3(256),
0,
arg.p_in_,
arg.p_wei_,
arg.p_out_,
arg.N_,
arg.K_,
arg.C_,
arg.in_spatial_lengths_[0],
arg.in_spatial_lengths_[1],
arg.in_spatial_lengths_[2],
arg.filter_spatial_lengths_[0],
arg.filter_spatial_lengths_[1],
arg.filter_spatial_lengths_[2],
arg.out_spatial_lengths_[0],
arg.out_spatial_lengths_[1],
arg.out_spatial_lengths_[2],
arg.conv_filter_strides_[0],
arg.conv_filter_strides_[1],
arg.conv_filter_strides_[2],
arg.conv_filter_dilations_[0],
arg.conv_filter_dilations_[1],
arg.conv_filter_dilations_[2],
arg.in_left_pads_[0],
arg.in_left_pads_[1],
arg.in_left_pads_[2]);
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
std::vector<index_t> out_spatial_lengths =
ConvolutionUtility::ComputeOutputSpatialLengths(arg.in_spatial_lengths_,
arg.filter_spatial_lengths_,
arg.conv_filter_strides_,
arg.conv_filter_dilations_,
arg.in_left_pads_,
arg.in_right_pads_);
bool out_lengths_are_consistent = out_spatial_lengths[0] == arg.out_spatial_lengths_[0] &&
out_spatial_lengths[1] == arg.out_spatial_lengths_[1] &&
out_spatial_lengths[2] == arg.out_spatial_lengths_[2];
return out_lengths_are_consistent;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const InDataType* p_in,
const WeiDataType* p_wei,
OutDataType* p_out,
const index_t N,
const index_t K,
const index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
return Argument{p_in,
p_wei,
p_out,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in,
const void* p_wei,
void* p_out,
const index_t N,
const index_t K,
const index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in),
static_cast<const WeiDataType*>(p_wei),
static_cast<OutDataType*>(p_out),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K<>";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_CONV3D_FWD_XDL_HPP
#define DEVICE_CONV3D_FWD_XDL_HPP
#include <iostream>
#include <memory>
#include <sstream>
#include "device.hpp"
#include "device_conv_fwd.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "convolution_forward_specialization.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r3_for_conv3d(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const index_t num_batches,
const index_t a_batch_stride,
const index_t b_batch_stride,
const index_t c_batch_stride,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / num_batches);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(a_batch_stride) * g_idx);
const long_index_t b_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(b_batch_stride) * g_idx);
const long_index_t c_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(c_batch_stride) * g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
template <typename InDataType,
typename WeiDataType, // WeiDataType must be the same as InDataType
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
using DeviceOp = DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K;
using ADataType = InDataType;
using BDataType = WeiDataType;
using CDataType = OutDataType;
// TODO make A/B datatype different
using ABDataType = InDataType;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
/*
* \brief Split the number of batches, \p N, into N = B * N1, such that the memory
* space of input and output tensors stays with the value range of index_t, and each subbatch
* can be dealed with GridwiseGemm.
*/
static index_t GetMaxAllowableSubBatchSize(const index_t N,
const index_t K,
const index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths)
{
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
// N1 should satisfy that
// 1) N % N1 = 0;
// 2) N1 * (Do * Ho * Wo * K) < (2^31 - 1)
// 3) N1 * (Di * Hi * Wi * C) < (2^31 - 1)
//
// Do NOT confuse (B, N1) in this function with (B, N1) in gridewise GEMM.
auto N1 = N + 1;
const auto stride =
math::max(long_index_t(Do) * Ho * Wo * K, long_index_t(Di) * Hi * Wi * C);
const index_t max_stride = NumericLimits<index_t>::Max();
for(index_t n0 = 1; n0 <= N; ++n0)
{
index_t n1 = N / n0;
if(n0 * n1 == N && long_index_t(n1) * long_index_t(stride) < max_stride)
{
N1 = n1;
break;
}
}
const auto B = N / N1;
if(B * N1 != N)
{
throw std::runtime_error(__func__ +
std::string(": failed to find num_subbatches for conv3d.\n"));
}
return N1;
}
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N,
const index_t K,
const index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
assert(input_spatial_lengths.size() > 2);
assert(filter_spatial_lengths.size() > 2);
assert(conv_filter_strides.size() > 2);
assert(conv_filter_dilations.size() > 2);
assert(input_left_pads.size() > 2);
assert(input_right_pads.size() > 2);
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
static_assert(ConvForwardSpecialization == -1, "Not implemented!");
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
static_assert(ConvForwardSpecialization == -1, "Not implemented!");
}
else
{
const auto in_desc_n_di_hi_wi_c =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto wei_desc_k_z_y_x_c =
make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
const auto out_desc_n_do_ho_wo_k =
make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
const auto descs =
transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad(
in_desc_n_di_hi_wi_c,
wei_desc_k_z_y_x_c,
out_desc_n_do_ho_wo_k,
make_tuple(
conv_filter_strides[0], conv_filter_strides[1], conv_filter_strides[2]),
make_tuple(conv_filter_dilations[0],
conv_filter_dilations[1],
conv_filter_dilations[2]),
make_tuple(input_left_pads[0], input_left_pads[1], input_left_pads[2]),
make_tuple(input_right_pads[0], input_right_pads[1], input_right_pads[2]),
Number<K1>{});
return descs;
}
}
using ABCGridDescs = remove_cvref_t<decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}))>;
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
struct Block2CTileMapMaker
{
Block2CTileMapMaker(index_t num_batches) : num_batches_(num_batches) {}
__host__ __device__ constexpr auto
MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(num_batches_),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto globalblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(num_batches_, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto globalblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
globalblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor);
return globalblockid_to_m0_n0_block_cluster_adaptor;
}
private:
index_t num_batches_;
};
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
InDataType,
AccDataType,
OutDataType,
InMemoryDataOperationEnum_t::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
7,
CThreadTransferDstScalarPerVector>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap =
decltype(Block2CTileMapMaker{1}.MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
// Argument
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in,
const WeiDataType* p_wei,
OutDataType* p_out,
const index_t N,
const index_t K,
const index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
index_t M01,
index_t N01,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_in},
p_b_grid_{p_wei},
p_c_grid_{p_out},
M01_{M01},
N01_{N01},
in_element_op_{in_element_op},
wei_element_op_{wei_element_op},
out_element_op_{out_element_op}
{
const index_t subbatch_size =
GetMaxAllowableSubBatchSize(N, K, C, input_spatial_lengths, output_spatial_lengths);
num_subbatches_ = N / subbatch_size;
const auto descs =
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(subbatch_size,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
a_batch_stride_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
b_batch_stride_ = 0;
c_batch_stride_ = c_grid_desc_m_n_.GetElementSpaceSize();
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = Block2CTileMapMaker{num_subbatches_}.MakeBlock2CTileMap(
c_grid_desc_m_n_, M01, N01);
}
}
// private:
const InDataType* p_a_grid_;
const WeiDataType* p_b_grid_;
OutDataType* p_c_grid_;
index_t num_subbatches_;
index_t a_batch_stride_;
index_t b_batch_stride_;
index_t c_batch_stride_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
{
{
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "b_grid_desc_k0_n_k1{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
// todo: grid_size times arg.num_subbatches_
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.num_subbatches_;
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
float ave_time = 0;
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d<
GridwiseGemm,
InDataType,
OutDataType,
remove_reference_t<AGridDesc_K0_M_K1>,
remove_reference_t<BGridDesc_K0_N_K1>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<Block2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.num_subbatches_,
arg.a_batch_stride_,
arg.b_batch_stride_,
arg.c_batch_stride_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d<
GridwiseGemm,
InDataType,
OutDataType,
remove_reference_t<AGridDesc_K0_M_K1>,
remove_reference_t<BGridDesc_K0_N_K1>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.num_subbatches_,
arg.a_batch_stride_,
arg.b_batch_stride_,
arg.c_batch_stride_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
arg.block_2_ctile_map_);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const InDataType* p_in,
const WeiDataType* p_wei,
OutDataType* p_out,
const index_t N,
const index_t K,
const index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
return Argument{p_in,
p_wei,
p_out,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
in_element_op,
wei_element_op,
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in,
const void* p_wei,
void* p_out,
const index_t N,
const index_t K,
const index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in),
static_cast<const WeiDataType*>(p_wei),
static_cast<OutDataType*>(p_out),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
in_element_op,
wei_element_op,
out_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -261,7 +261,8 @@ struct DeviceGemmXdl
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -274,7 +275,7 @@ struct DeviceGemmXdl
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
......@@ -309,7 +310,7 @@ struct DeviceGemmXdl
arg.N01_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
......@@ -332,7 +333,7 @@ struct DeviceGemmXdl
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
......@@ -363,7 +364,7 @@ struct DeviceGemmXdl
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
......
......@@ -221,7 +221,8 @@ struct DeviceGemmXdl_C_Shuffle
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -235,7 +236,7 @@ struct DeviceGemmXdl_C_Shuffle
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
......@@ -295,7 +296,7 @@ struct DeviceGemmXdl_C_Shuffle
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(
......@@ -329,7 +330,7 @@ struct DeviceGemmXdl_C_Shuffle
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(
......
......@@ -235,7 +235,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -254,7 +255,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
......@@ -320,7 +321,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(
......@@ -359,7 +360,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(
......
......@@ -240,7 +240,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c0_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -259,7 +260,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
typename GridwiseGemm::
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
......@@ -325,7 +326,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(
......@@ -364,7 +365,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(
......
......@@ -274,7 +274,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c1_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -298,7 +299,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
typename GridwiseGemm::
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
......@@ -370,7 +371,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(
......@@ -414,7 +415,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(
......
......@@ -45,6 +45,18 @@ struct NKHW : public BaseTensorLayout
{
};
struct NDHWC : public BaseTensorLayout
{
};
struct KZYXC : public BaseTensorLayout
{
};
struct NDHWK : public BaseTensorLayout
{
};
} // namespace convolution
} // namespace tensor_layout
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment