"vscode:/vscode.git/clone" did not exist on "191a3f6f9d5fe4f0e4b96e3a6be95d50a14efd63"
Unverified Commit 6dfb92bb authored by Jianfeng Yan's avatar Jianfeng Yan Committed by GitHub
Browse files

Conv3d new (#94)



* conv3d compiles but has memory error

* conv3d works

* fix performance issue by using __builtin_amdgc_readfirstlane

* change MakeBlock2CTileMap to MakeDefaultBlock2CTileMap; change c_blockid_to* to cblockid_to*

* clang-format

* remove CK_EXPERIMENTAL_PASS_TENSOR_DECRIPTOR_BY_*; moved wrapper into DeviceConv3d

* format

* remove useless marc

* add comment
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent 19c5d6e6
...@@ -111,24 +111,39 @@ struct MagicDivision ...@@ -111,24 +111,39 @@ struct MagicDivision
} }
// magic division for uint32_t // magic division for uint32_t
__host__ __device__ static constexpr uint32_t __device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{ {
uint32_t tmp = __umulhi(dividend, multiplier); uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift; return (tmp + dividend) >> shift;
} }
__host__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = static_cast<uint64_t>(dividend) * multiplier >> 32;
return (tmp + dividend) >> shift;
}
// magic division for int32_t // magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct // non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended // TODO: figure out how to do magic number divison for int32_t as dividended
__host__ __device__ static constexpr int32_t __device__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{ {
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32); uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier); uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift; return (tmp + dividend_u32) >> shift;
} }
__host__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = static_cast<uint64_t>(dividend_u32) * multiplier >> 32;
return (tmp + dividend_u32) >> shift;
}
}; };
} // namespace ck } // namespace ck
......
...@@ -8,37 +8,5 @@ namespace ck { ...@@ -8,37 +8,5 @@ namespace ck {
template <index_t N> template <index_t N>
using Number = integral_constant<index_t, N>; using Number = integral_constant<index_t, N>;
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
return Number<X + Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
static_assert(Y <= X, "wrong!");
return Number<X - Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
return Number<X * Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X / Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X % Y>{};
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -13,6 +13,8 @@ __device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size() ...@@ -13,6 +13,8 @@ __device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size()
__device__ index_t get_block_1d_id() { return blockIdx.x; } __device__ index_t get_block_1d_id() { return blockIdx.x; }
__device__ index_t get_grid_size() { return gridDim.x; }
} // namespace ck } // namespace ck
#endif #endif
...@@ -83,7 +83,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy ...@@ -83,7 +83,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
void* p_a_k_m0_m1_grid_desc, void* p_a_k_m0_m1_grid_desc,
void* p_b_k_n0_n1_grid_desc, void* p_b_k_n0_n1_grid_desc,
void* p_c_m0_m10_m11_n0_n10_n11_grid_desc, void* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor) void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -194,7 +194,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy ...@@ -194,7 +194,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
auto c_m0_m10_m11_n0_n10_n11_grid_desc = auto c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor = auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0) if(hipThreadIdx_x == 0)
...@@ -203,8 +203,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy ...@@ -203,8 +203,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
*static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc; *static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc;
*static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>( *static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>(
p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc; p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>( *static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
}; };
}; };
...@@ -219,7 +219,7 @@ extern "C" __global__ void ...@@ -219,7 +219,7 @@ extern "C" __global__ void
const void CONSTANT* p_a_k_m0_m1_grid_desc, const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void CONSTANT* p_b_k_n0_n1_grid_desc, const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -332,14 +332,13 @@ extern "C" __global__ void ...@@ -332,14 +332,13 @@ extern "C" __global__ void
GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp = constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp); using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp);
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp); using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp);
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp); using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k_m0_m1_grid_desc = const auto a_k_m0_m1_grid_desc =
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc); *reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
...@@ -348,9 +347,9 @@ extern "C" __global__ void ...@@ -348,9 +347,9 @@ extern "C" __global__ void
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>( *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>( *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); (const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -364,7 +363,7 @@ extern "C" __global__ void ...@@ -364,7 +363,7 @@ extern "C" __global__ void
a_k_m0_m1_grid_desc, a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc, b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor, cblockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
}; };
...@@ -79,7 +79,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc ...@@ -79,7 +79,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
void* p_a_k0_m_k1_grid_desc, void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc, void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc, void* p_c_m0_m1_m2_n_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor) void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -188,7 +188,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc ...@@ -188,7 +188,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor = auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0) if(hipThreadIdx_x == 0)
...@@ -199,8 +199,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc ...@@ -199,8 +199,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
b_k0_n_k1_grid_desc; b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) = *static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc; c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>( *static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
} }
}; };
...@@ -215,7 +215,7 @@ extern "C" __global__ void ...@@ -215,7 +215,7 @@ extern "C" __global__ void
const void CONSTANT* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -325,12 +325,11 @@ extern "C" __global__ void ...@@ -325,12 +325,11 @@ extern "C" __global__ void
constexpr auto c_m0_m1_m2_n_grid_desc_tmp = constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc = const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc); *reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
...@@ -338,9 +337,9 @@ extern "C" __global__ void ...@@ -338,9 +337,9 @@ extern "C" __global__ void
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc); *reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc = const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc); *reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>( *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); (const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -354,5 +353,5 @@ extern "C" __global__ void ...@@ -354,5 +353,5 @@ extern "C" __global__ void
a_k0_m_k1_grid_desc, a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc, b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor); cblockid_to_m0_n0_block_cluster_adaptor);
}; };
...@@ -79,7 +79,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky ...@@ -79,7 +79,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
void* p_a_k0_m_k1_grid_desc, void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc, void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc, void* p_c_m0_m1_m2_n_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor) void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -188,7 +188,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky ...@@ -188,7 +188,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor = auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0) if(hipThreadIdx_x == 0)
...@@ -199,8 +199,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky ...@@ -199,8 +199,8 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
b_k0_n_k1_grid_desc; b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) = *static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc; c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>( *static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
} }
}; };
...@@ -215,7 +215,7 @@ extern "C" __global__ void ...@@ -215,7 +215,7 @@ extern "C" __global__ void
const void CONSTANT* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -324,12 +324,11 @@ extern "C" __global__ void ...@@ -324,12 +324,11 @@ extern "C" __global__ void
false>; false>;
constexpr auto c_m0_m1_m2_n_grid_desc_tmp = constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc = const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc); *reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
...@@ -337,9 +336,9 @@ extern "C" __global__ void ...@@ -337,9 +336,9 @@ extern "C" __global__ void
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc); *reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc = const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc); *reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>( *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); (const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -353,5 +352,5 @@ extern "C" __global__ void ...@@ -353,5 +352,5 @@ extern "C" __global__ void
a_k0_m_k1_grid_desc, a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc, b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc, c_m0_m1_m2_n_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor); cblockid_to_m0_n0_block_cluster_adaptor);
}; };
#ifndef CONVOLUTION_UTILITY_HPP
#define CONVOLUTION_UTILITY_HPP
#include <vector>
namespace ck {
namespace tensor_operation {
struct ConvolutionUtility
{
static std::vector<ck::index_t>
ComputeOutputSpatialLengths(std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_strides,
std::vector<ck::index_t> conv_dilations,
std::vector<ck::index_t> in_left_pads,
std::vector<ck::index_t> in_right_pads)
{
if(input_spatial_lengths.size() == 2)
{
assert(filter_spatial_lengths.size() == 2);
assert(conv_strides.size() == 2);
assert(conv_dilations.size() == 2);
assert(in_left_pads.size() == 2);
assert(in_right_pads.size() == 2);
const index_t YEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t XEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho =
(Hi + in_left_pads[0] + in_right_pads[0] - YEff) / conv_strides[0] + 1;
const index_t Wo =
(Wi + in_left_pads[1] + in_right_pads[1] - XEff) / conv_strides[1] + 1;
return {Ho, Wo};
}
else if(input_spatial_lengths.size() == 3)
{
assert(filter_spatial_lengths.size() == 3);
assert(conv_strides.size() == 3);
assert(conv_dilations.size() == 3);
assert(in_left_pads.size() == 3);
assert(in_right_pads.size() == 3);
const index_t ZEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t YEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t XEff = (filter_spatial_lengths[2] - 1) * conv_dilations[2] + 1;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do =
(Di + in_left_pads[0] + in_right_pads[0] - ZEff) / conv_strides[0] + 1;
const index_t Ho =
(Hi + in_left_pads[1] + in_right_pads[1] - YEff) / conv_strides[1] + 1;
const index_t Wo =
(Wi + in_left_pads[2] + in_right_pads[2] - XEff) / conv_strides[2] + 1;
return {Do, Ho, Wo};
}
else
{
return {};
}
}
};
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -248,7 +248,7 @@ struct DeviceBatchedGemmXdl ...@@ -248,7 +248,7 @@ struct DeviceBatchedGemmXdl
c_grid_desc_g_m_n_); c_grid_desc_g_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseBatchedGemm::MakeBlock2CTileMap(c_grid_desc_g_m_n_, M01, N01); GridwiseBatchedGemm::MakeDefaultBlock2CTileMap(c_grid_desc_g_m_n_, M01, N01);
} }
} }
...@@ -261,7 +261,7 @@ struct DeviceBatchedGemmXdl ...@@ -261,7 +261,7 @@ struct DeviceBatchedGemmXdl
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
typename GridwiseBatchedGemm::CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2 typename GridwiseBatchedGemm::CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2_; c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseBatchedGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseBatchedGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -327,7 +327,7 @@ struct DeviceBatchedGemmXdl ...@@ -327,7 +327,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseBatchedGemm::Block2CTileMap>, remove_reference_t<typename GridwiseBatchedGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
...@@ -359,7 +359,7 @@ struct DeviceBatchedGemmXdl ...@@ -359,7 +359,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseBatchedGemm::Block2CTileMap>, remove_reference_t<typename GridwiseBatchedGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
......
...@@ -590,7 +590,8 @@ struct ...@@ -590,7 +590,8 @@ struct
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c1_grid_desc_m_n_); c1_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
} }
...@@ -614,7 +615,7 @@ struct ...@@ -614,7 +615,7 @@ struct
typename GridwiseGemm:: typename GridwiseGemm::
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
...@@ -694,7 +695,7 @@ struct ...@@ -694,7 +695,7 @@ struct
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -738,7 +739,7 @@ struct ...@@ -738,7 +739,7 @@ struct
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
......
...@@ -561,7 +561,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -561,7 +561,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c0_grid_desc_m_n_); c0_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
} }
...@@ -579,7 +580,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -579,7 +580,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
typename GridwiseGemm:: typename GridwiseGemm::
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
...@@ -653,7 +654,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -653,7 +654,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -692,7 +693,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -692,7 +693,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
......
...@@ -525,7 +525,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -525,7 +525,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_); c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
} }
...@@ -538,7 +539,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -538,7 +539,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
typename GridwiseGemm:: typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
...@@ -628,7 +629,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -628,7 +629,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -662,7 +663,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -662,7 +663,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
......
...@@ -415,7 +415,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -415,7 +415,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
} }
...@@ -428,7 +429,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -428,7 +429,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
...@@ -471,7 +472,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -471,7 +472,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
arg.N01_)) arg.N01_))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
} }
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
...@@ -494,7 +495,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -494,7 +495,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
...@@ -525,7 +526,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -525,7 +526,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
......
#ifndef DEVICE_CONV3D_FWD_NAIVE_HPP
#define DEVICE_CONV3D_FWD_NAIVE_HPP
#include <iostream>
#include <memory>
#include <sstream>
#include "convolution_utility.hpp"
#include "device.hpp"
#include "device_conv_fwd.hpp"
#include "common_header.hpp"
#include "naive_conv_fwd.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
template <typename InDataType,
typename WeiDataType, // WeiDataType must be the same as InDataType
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
using DeviceOp = DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K;
using ADataType = InDataType;
using BDataType = WeiDataType;
using CDataType = OutDataType;
// TODO make A/B datatype different
using ABDataType = InDataType;
// Argument
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in,
const WeiDataType* p_wei,
OutDataType* p_out,
const index_t N,
const index_t K,
const index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: N_{N},
K_{K},
C_{C},
in_spatial_lengths_{input_spatial_lengths},
filter_spatial_lengths_{filter_spatial_lengths},
out_spatial_lengths_{output_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
in_right_pads_{input_right_pads},
p_in_{p_in},
p_wei_{p_wei},
p_out_{p_out},
in_element_op_{in_element_op},
wei_element_op_{wei_element_op},
out_element_op_{out_element_op}
{
}
// private:
index_t N_;
index_t K_;
index_t C_;
std::vector<index_t> in_spatial_lengths_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> out_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> conv_filter_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
const InDataType* p_in_;
const WeiDataType* p_wei_;
OutDataType* p_out_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
{
const auto naive_conv3d_fwd =
ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>;
float ave_time = launch_and_time_kernel(naive_conv3d_fwd,
nrepeat,
dim3(256),
dim3(256),
0,
arg.p_in_,
arg.p_wei_,
arg.p_out_,
arg.N_,
arg.K_,
arg.C_,
arg.in_spatial_lengths_[0],
arg.in_spatial_lengths_[1],
arg.in_spatial_lengths_[2],
arg.filter_spatial_lengths_[0],
arg.filter_spatial_lengths_[1],
arg.filter_spatial_lengths_[2],
arg.out_spatial_lengths_[0],
arg.out_spatial_lengths_[1],
arg.out_spatial_lengths_[2],
arg.conv_filter_strides_[0],
arg.conv_filter_strides_[1],
arg.conv_filter_strides_[2],
arg.conv_filter_dilations_[0],
arg.conv_filter_dilations_[1],
arg.conv_filter_dilations_[2],
arg.in_left_pads_[0],
arg.in_left_pads_[1],
arg.in_left_pads_[2]);
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
std::vector<index_t> out_spatial_lengths =
ConvolutionUtility::ComputeOutputSpatialLengths(arg.in_spatial_lengths_,
arg.filter_spatial_lengths_,
arg.conv_filter_strides_,
arg.conv_filter_dilations_,
arg.in_left_pads_,
arg.in_right_pads_);
bool out_lengths_are_consistent = out_spatial_lengths[0] == arg.out_spatial_lengths_[0] &&
out_spatial_lengths[1] == arg.out_spatial_lengths_[1] &&
out_spatial_lengths[2] == arg.out_spatial_lengths_[2];
return out_lengths_are_consistent;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const InDataType* p_in,
const WeiDataType* p_wei,
OutDataType* p_out,
const index_t N,
const index_t K,
const index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
return Argument{p_in,
p_wei,
p_out,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in,
const void* p_wei,
void* p_out,
const index_t N,
const index_t K,
const index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in),
static_cast<const WeiDataType*>(p_wei),
static_cast<OutDataType*>(p_out),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K<>";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -261,7 +261,8 @@ struct DeviceGemmXdl ...@@ -261,7 +261,8 @@ struct DeviceGemmXdl
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
} }
...@@ -274,7 +275,7 @@ struct DeviceGemmXdl ...@@ -274,7 +275,7 @@ struct DeviceGemmXdl
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -309,7 +310,7 @@ struct DeviceGemmXdl ...@@ -309,7 +310,7 @@ struct DeviceGemmXdl
arg.N01_)) arg.N01_))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
} }
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
...@@ -332,7 +333,7 @@ struct DeviceGemmXdl ...@@ -332,7 +333,7 @@ struct DeviceGemmXdl
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
...@@ -363,7 +364,7 @@ struct DeviceGemmXdl ...@@ -363,7 +364,7 @@ struct DeviceGemmXdl
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
......
...@@ -221,7 +221,8 @@ struct DeviceGemmXdl_C_Shuffle ...@@ -221,7 +221,8 @@ struct DeviceGemmXdl_C_Shuffle
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_); c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
} }
...@@ -235,7 +236,7 @@ struct DeviceGemmXdl_C_Shuffle ...@@ -235,7 +236,7 @@ struct DeviceGemmXdl_C_Shuffle
typename GridwiseGemm:: typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -295,7 +296,7 @@ struct DeviceGemmXdl_C_Shuffle ...@@ -295,7 +296,7 @@ struct DeviceGemmXdl_C_Shuffle
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -329,7 +330,7 @@ struct DeviceGemmXdl_C_Shuffle ...@@ -329,7 +330,7 @@ struct DeviceGemmXdl_C_Shuffle
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
......
...@@ -235,7 +235,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d ...@@ -235,7 +235,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_); c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
} }
...@@ -254,7 +255,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d ...@@ -254,7 +255,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
typename GridwiseGemm:: typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -320,7 +321,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d ...@@ -320,7 +321,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -359,7 +360,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d ...@@ -359,7 +360,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
......
...@@ -240,7 +240,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation ...@@ -240,7 +240,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c0_grid_desc_m_n_); c0_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
} }
...@@ -259,7 +260,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation ...@@ -259,7 +260,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
typename GridwiseGemm:: typename GridwiseGemm::
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -325,7 +326,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation ...@@ -325,7 +326,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -364,7 +365,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation ...@@ -364,7 +365,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
......
...@@ -274,7 +274,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add ...@@ -274,7 +274,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c1_grid_desc_m_n_); c1_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
} }
...@@ -298,7 +299,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add ...@@ -298,7 +299,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
typename GridwiseGemm:: typename GridwiseGemm::
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -370,7 +371,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add ...@@ -370,7 +371,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -414,7 +415,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add ...@@ -414,7 +415,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
......
...@@ -45,6 +45,18 @@ struct NKHW : public BaseTensorLayout ...@@ -45,6 +45,18 @@ struct NKHW : public BaseTensorLayout
{ {
}; };
struct NDHWC : public BaseTensorLayout
{
};
struct KZYXC : public BaseTensorLayout
{
};
struct NDHWK : public BaseTensorLayout
{
};
} // namespace convolution } // namespace convolution
} // namespace tensor_layout } // namespace tensor_layout
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment