Commit 95a83c6e authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into wavelet_model

parents 5b7c2432 892a8d76
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
namespace ck {
enum struct PipelineVersion
{
v1,
v2,
};
template <PipelineVersion PipelineVer,
index_t NumPrefetch = 1,
LoopScheduler LoopSched = LoopScheduler::Default>
constexpr auto GridwiseGemmPipeline_Selector()
{
if constexpr(PipelineVer == PipelineVersion::v1)
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
return GridwiseGemmPipelineInterwave_v1<NumPrefetch>{};
}
}
else if constexpr(PipelineVer == PipelineVersion::v2)
{
return GridwiseGemmPipeline_v2{};
}
else
{
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
}
}
} // namespace ck
...@@ -352,6 +352,7 @@ struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2> ...@@ -352,6 +352,7 @@ struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
{ {
}; };
// TODO: deprecate as GridwiseGemmPipeline_Selector covers the functionality
template <index_t NumPrefetch, LoopScheduler LoopSched> template <index_t NumPrefetch, LoopScheduler LoopSched>
constexpr auto GridwiseGemmPipeline_v1_Selector() constexpr auto GridwiseGemmPipeline_v1_Selector()
{ {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...@@ -142,7 +142,8 @@ template <typename FloatAB, ...@@ -142,7 +142,8 @@ template <typename FloatAB,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock, typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched> LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -162,7 +163,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -162,7 +163,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>; using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
...@@ -481,7 +483,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -481,7 +483,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// gridwise GEMM pipeline // gridwise GEMM pipeline
const auto gridwise_gemm_pipeline = const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>(); GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...@@ -115,7 +114,8 @@ template <typename FloatAB, ...@@ -115,7 +114,8 @@ template <typename FloatAB,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched> LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -136,13 +136,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -136,13 +136,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
using GridwiseGemmPipe = using GridwiseGemmPipe = remove_cvref_t<decltype(
#if 1 GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
remove_cvref_t<decltype(
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>())>;
#else
GridwiseGemmPipeline_v2;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...@@ -151,7 +151,8 @@ template <typename FloatAB, ...@@ -151,7 +151,8 @@ template <typename FloatAB,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock, typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
LoopScheduler LoopSched> LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -171,7 +172,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -171,7 +172,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>; using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
...@@ -519,7 +521,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -519,7 +521,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// gridwise GEMM pipeline // gridwise GEMM pipeline
const auto gridwise_gemm_pipeline = const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>(); GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...@@ -243,7 +243,8 @@ template <index_t BlockSize, ...@@ -243,7 +243,8 @@ template <index_t BlockSize,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
bool ABlockLdsExtraM1Wrw = false, bool ABlockLdsExtraM1Wrw = false,
bool BBlockLdsExtraN1Wrw = false, bool BBlockLdsExtraN1Wrw = false,
index_t NumGemmKPrefetchStage = 1> index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -258,8 +259,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -258,8 +259,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// M0/M1/M1Padding // M0/M1/M1Padding
static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{}; static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
...@@ -109,7 +109,9 @@ template <index_t BlockSize, ...@@ -109,7 +109,9 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
index_t NumGemmKPrefetchStage = 1> index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -126,7 +128,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -126,7 +128,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>; using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
...@@ -423,18 +426,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -423,18 +426,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
// sanity check // sanity check
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
auto blockwise_gemm = BlockSize,
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, FloatAB,
FloatAB, FloatAcc,
FloatAcc, decltype(a_block_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1), decltype(b_block_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1), MPerXDL,
MPerXDL, NPerXDL,
NPerXDL, MXdlPerWave,
MXdlPerWave, NXdlPerWave,
NXdlPerWave, K1,
K1>{}; LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...@@ -117,7 +117,8 @@ template < ...@@ -117,7 +117,8 @@ template <
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl, index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
index_t NumGemmKPrefetchStage = 1> index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -137,7 +138,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -137,7 +138,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>; using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp"
...@@ -123,7 +123,8 @@ template < ...@@ -123,7 +123,8 @@ template <
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl, index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
index_t NumGemmKPrefetchStage = 1> index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -140,7 +141,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 ...@@ -140,7 +141,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>; using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp"
...@@ -132,7 +132,8 @@ template < ...@@ -132,7 +132,8 @@ template <
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl, index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
index_t NumGemmKPrefetchStage = 1> index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -149,7 +150,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -149,7 +150,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>; using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
namespace ck { namespace ck {
// Y = LayerNorm(X, Beta, Gamma) // Y = Normalization(X, Beta, Gamma)
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
...@@ -36,7 +36,7 @@ template <typename XDataType, ...@@ -36,7 +36,7 @@ template <typename XDataType,
index_t YDstVectorDim, index_t YDstVectorDim,
index_t YDstVectorSize, index_t YDstVectorSize,
bool SweepOnce> bool SweepOnce>
struct GridwiseLayernormNaiveVariance_mk_to_mk struct GridwiseNormalizationNaiveVariance_mk_to_mk
{ {
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
namespace ck { namespace ck {
// Y = LayerNorm(X, Beta, Gamma) // Y = Normalization(X, Beta, Gamma)
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
...@@ -33,7 +33,7 @@ template <typename XDataType, ...@@ -33,7 +33,7 @@ template <typename XDataType,
index_t YDstVectorDim, index_t YDstVectorDim,
index_t YDstVectorSize, index_t YDstVectorSize,
bool SweepOnce> bool SweepOnce>
struct GridwiseLayernormWelfordVariance_mk_to_mk struct GridwiseNormalizationWelfordVariance_mk_to_mk
{ {
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#pragma once #pragma once
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck { namespace ck {
......
...@@ -75,4 +75,63 @@ struct ThreadwiseWelford ...@@ -75,4 +75,63 @@ struct ThreadwiseWelford
int max_count_; int max_count_;
}; };
template <typename T,
typename SrcMeanVarCountThreadDesc_M_K,
typename DstMeanVarThreadDesc_M,
bool GetActualVariance = false>
struct ThreadwiseWelfordMerge
{
static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{};
static constexpr auto dst_thread_desc_m = DstMeanVarThreadDesc_M{};
static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
__device__ static void
Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b)
{
int count = count_a + count_b;
T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
T delta = mean_b - mean_a;
mean_a += delta * count_b_over_count;
var_a += var_b + delta * delta * count_a * count_b_over_count;
count_a = count;
}
template <typename SrcMeanBufferType,
typename SrcVarBufferType,
typename SrcCountBufferType,
typename DstMeanBufferType,
typename DstVarBufferType,
typename DstCountBufferType>
__device__ static void Run(const SrcMeanBufferType& src_mean_buf,
const SrcVarBufferType& src_var_buf,
const SrcCountBufferType& src_count_buf,
DstMeanBufferType& dst_mean_buf,
DstVarBufferType& dst_var_buf,
DstCountBufferType& dst_count_buf)
{
static_for<0, src_length_m, 1>{}([&](auto iM) {
static_for<0, src_length_k, 1>{}([&](auto iK) {
constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
Merge(dst_mean_buf(iM),
dst_var_buf(iM),
dst_count_buf(iM),
src_mean_buf[Number<src_offset>{}],
src_var_buf[Number<src_offset>{}],
src_count_buf[Number<src_offset>{}]);
});
if constexpr(GetActualVariance)
{
dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM];
};
});
};
};
} // namespace ck } // namespace ck
...@@ -593,7 +593,8 @@ struct XdlopsGemm ...@@ -593,7 +593,8 @@ struct XdlopsGemm
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>; using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
__device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; } __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
...@@ -822,6 +823,16 @@ struct XdlopsGemm ...@@ -822,6 +823,16 @@ struct XdlopsGemm
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
} }
__device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
{
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{}; static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};
static constexpr auto mfma_instr = mfma.selected_mfma; static constexpr auto mfma_instr = mfma.selected_mfma;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
namespace ck {
namespace tensor_operation {
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
device::TensorSpecialization TensorSpec>
static auto MakeGridDescriptorPair(const std::vector<index_t>& gs_ms_ns_lengths_vec,
const std::vector<index_t>& gs_ms_ns_strides_vec)
{
if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
{
throw std::runtime_error("wrong! dimension must match input lengths");
}
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto gs_ms_ns_lengths =
to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto gs_ms_ns_strides =
to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimG + NumDimM, NumDimG + NumDimM + NumDimN, 1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);
if constexpr(TensorSpec == device::TensorSpecialization::Packed)
{
auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
make_tuple(G, M, N),
make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
make_tuple(M, N),
make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
}
else
{
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
// Note: This does not require padding as it only provides G offset calculation. Technically
// descriptor for only G is needed. Here we opt for backward compatibility purpose to return
// G_M_N
const auto grid_desc_g_mraw_nraw =
transform_tensor_descriptor(grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto c_ms_ns_lengths = to_tuple(
gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
}
}
template <typename NumDims_G_M_N_K_O, // Sequence<>
typename PerBlock_M_N_K_O, // Sequence<>
device::GemmSpecialization GemmSpec,
device::TensorSpecialization ASpec,
device::TensorSpecialization B0Spec,
device::TensorSpecialization B1Spec,
device::TensorSpecialization CSpec>
struct TransformBatchedContractionContractionToBatchedGemmGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);
static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);
static constexpr auto matrix_padder =
device::GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, OPerBlock};
//
// A
//
static auto MakeAGridDescriptorPair(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
a_gs_ms_ks_strides_vec);
}
// TODO: rename to G_MRaw_KRaw
static auto MakeAGridDescriptor_G_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
}
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
return matrix_padder.PadADescriptor_M_K(
MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
}
template <typename AGridDesc_M_K, typename Number>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
//
// B (alias of B0)
//
static auto MakeB0GridDescriptorPair(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
b0_gs_ns_ks_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
static auto MakeB0GridDescriptor_G_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
{
return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
}
static auto MakeB0GridDescriptor_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
{
// alias of matrix_padder.PadB0Descriptor_N_K
return matrix_padder.PadBDescriptor_N_K(
MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
}
template <typename BGridDesc_N_K, typename Number>
__host__ __device__ static constexpr auto
MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
//
// B1
//
static auto MakeB1GridDescriptorPair(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
const std::vector<index_t>& b1_gs_os_ns_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
b1_gs_os_ns_strides_vec);
}
// TODO: rename to G_NRaw_KRaw
static auto MakeB1GridDescriptor_G_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
const std::vector<index_t>& b1_gs_os_ns_strides_vec)
{
return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
}
static auto MakeB1GridDescriptor_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
const std::vector<index_t>& b1_gs_os_ns_strides_vec)
{
// alias of matrix_padder.PadB1Descriptor_O_N
return matrix_padder.PadB1Descriptor_N_K(
MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
}
template <typename B1GridDesc_N_K, typename Number>
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
{
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
//
// C
//
static auto MakeCGridDescriptorPair(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
const std::vector<index_t>& c_gs_ms_os_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
c_gs_ms_os_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
const std::vector<index_t>& c_gs_ms_os_strides_vec)
{
return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
}
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
const std::vector<index_t>& c_gs_ms_os_strides_vec)
{
return matrix_padder.PadCDescriptor_M_N(
MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
}
};
} // namespace tensor_operation
} // namespace ck
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck/library/utility/numeric.hpp"
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
...@@ -47,10 +48,9 @@ struct TransformConvFwdToGemm ...@@ -47,10 +48,9 @@ struct TransformConvFwdToGemm
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
const index_t NWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, const index_t NWo =
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, N * ck::accumulate_n<index_t>(
index_t{1}, c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
std::multiplies<index_t>());
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(NWo, C)); make_naive_tensor_descriptor_packed(make_tuple(NWo, C));
...@@ -146,10 +146,9 @@ struct TransformConvFwdToGemm ...@@ -146,10 +146,9 @@ struct TransformConvFwdToGemm
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, const index_t NHoWo =
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, N * ck::accumulate_n<index_t>(
index_t{1}, c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
std::multiplies<index_t>());
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C)); make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C));
...@@ -262,10 +261,8 @@ struct TransformConvFwdToGemm ...@@ -262,10 +261,8 @@ struct TransformConvFwdToGemm
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
const index_t NDoHoWo = const index_t NDoHoWo =
N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
index_t{1},
std::multiplies<index_t>());
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C)); make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C));
...@@ -390,10 +387,9 @@ struct TransformConvFwdToGemm ...@@ -390,10 +387,9 @@ struct TransformConvFwdToGemm
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, const index_t NHoWo =
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, N * ck::accumulate_n<index_t>(
index_t{1}, c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
std::multiplies<index_t>());
// This is different // This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
...@@ -506,10 +502,9 @@ struct TransformConvFwdToGemm ...@@ -506,10 +502,9 @@ struct TransformConvFwdToGemm
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, const index_t NHoWo =
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, N * ck::accumulate_n<index_t>(
index_t{1}, c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
std::multiplies<index_t>());
// This is different // This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
...@@ -639,10 +634,8 @@ struct TransformConvFwdToGemm ...@@ -639,10 +634,8 @@ struct TransformConvFwdToGemm
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
const index_t NDoHoWo = const index_t NDoHoWo =
N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
index_t{1},
std::multiplies<index_t>());
// This is different // This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
...@@ -768,10 +761,8 @@ struct TransformConvFwdToGemm ...@@ -768,10 +761,8 @@ struct TransformConvFwdToGemm
const index_t K = b_g_k_c_xs_lengths[1]; const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2]; const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3, const index_t YX = ck::accumulate_n<index_t>(
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial, b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
index_t{1},
std::multiplies<index_t>());
const auto wei_gemmn_gemmk_desc = const auto wei_gemmn_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
...@@ -794,10 +785,8 @@ struct TransformConvFwdToGemm ...@@ -794,10 +785,8 @@ struct TransformConvFwdToGemm
const index_t K = b_g_k_c_xs_lengths[1]; const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2]; const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3, const index_t YX = ck::accumulate_n<index_t>(
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial, b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
index_t{1},
std::multiplies<index_t>());
const index_t KStride = b_g_k_c_xs_strides[1]; const index_t KStride = b_g_k_c_xs_strides[1];
const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial]; const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
...@@ -827,10 +816,9 @@ struct TransformConvFwdToGemm ...@@ -827,10 +816,9 @@ struct TransformConvFwdToGemm
const index_t N = c_g_n_k_wos_lengths[1]; const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2]; const index_t K = c_g_n_k_wos_lengths[2];
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, const index_t NHoWo =
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, N * ck::accumulate_n<index_t>(
index_t{1}, c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
std::multiplies<index_t>());
const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K)); const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
...@@ -855,10 +843,9 @@ struct TransformConvFwdToGemm ...@@ -855,10 +843,9 @@ struct TransformConvFwdToGemm
const auto KStride = I1; const auto KStride = I1;
const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2]; const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, const index_t NHoWo =
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, N * ck::accumulate_n<index_t>(
index_t{1}, c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
std::multiplies<index_t>());
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride)); make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
...@@ -878,10 +865,9 @@ struct TransformConvFwdToGemm ...@@ -878,10 +865,9 @@ struct TransformConvFwdToGemm
const index_t N = c_g_n_k_wos_lengths[1]; const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2]; const index_t K = c_g_n_k_wos_lengths[2];
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, const index_t NHoWo =
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, N * ck::accumulate_n<index_t>(
index_t{1}, c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
std::multiplies<index_t>());
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1)); make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1));
......
...@@ -254,7 +254,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16> ...@@ -254,7 +254,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
......
...@@ -9,46 +9,61 @@ ...@@ -9,46 +9,61 @@
#include <algorithm> #include <algorithm>
#include <thread> #include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
template <typename InOutDataType, typename AccDataType> template <typename XDataType,
struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatchNormFwd<4, 3> typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp>
struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
: public device::DeviceBatchNormFwd<4, 3, YElementwiseOp>
{ {
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
{ {
Argument(const std::array<index_t, 4> xyLengths, Argument(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides, const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides, const std::array<index_t, 4> yStrides,
const std::array<int, 3> reduceDims,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths, const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleBiasMeanVarStrides, const std::array<index_t, 1> bnScaleStrides,
const InOutDataType* p_x, const std::array<index_t, 1> bnBiasStrides,
const AccDataType* bnScale, const std::array<index_t, 1> bnMeanVarStrides,
const AccDataType* bnBias, const XDataType* p_x,
InOutDataType* p_y, const ScaleDataType* bnScale,
double exponentialAverageFactor, const BiasDataType* bnBias,
AccDataType* resultRunningMean,
AccDataType* resultRunningVariance,
double epsilon, double epsilon,
AccDataType* resultSaveMean, const YElementwiseOp y_elementwise_op,
AccDataType* resultSaveInvVariance) YDataType* p_y,
MeanVarDataType* resultSaveMean,
MeanVarDataType* resultSaveInvVariance,
double averageFactor,
MeanVarDataType* resultRunningMean,
MeanVarDataType* resultRunningVariance)
: p_x_(p_x), : p_x_(p_x),
bnScale_(bnScale), bnScale_(bnScale),
bnBias_(bnBias), bnBias_(bnBias),
y_elementwise_op_(y_elementwise_op),
p_y_(p_y), p_y_(p_y),
resultRunningMean_(resultRunningMean),
resultRunningVariance_(resultRunningVariance),
resultSaveMean_(resultSaveMean), resultSaveMean_(resultSaveMean),
resultSaveInvVariance_(resultSaveInvVariance), resultSaveInvVariance_(resultSaveInvVariance),
exponentialAverageFactor_(exponentialAverageFactor), resultRunningMean_(resultRunningMean),
epsilon_(epsilon) resultRunningVariance_(resultRunningVariance)
{ {
(void)xStrides; ignore = xStrides;
(void)yStrides; ignore = yStrides;
(void)bnScaleBiasMeanVarStrides; ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
ignore = reduceDims;
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 || if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
bnScaleBiasMeanVarLengths[0] != xyLengths[3]) bnScaleBiasMeanVarLengths[0] != xyLengths[3])
...@@ -59,26 +74,30 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -59,26 +74,30 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
w = xyLengths[2]; w = xyLengths[2];
c = xyLengths[3]; c = xyLengths[3];
epsilon_ = type_convert<AccDataType>(epsilon);
averageFactor_ = type_convert<AccDataType>(averageFactor);
resultSave = (resultSaveMean != nullptr && resultSaveInvVariance != nullptr); resultSave = (resultSaveMean != nullptr && resultSaveInvVariance != nullptr);
resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr); resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr);
} }
const InOutDataType* p_x_; const XDataType* p_x_;
const AccDataType* bnScale_; const ScaleDataType* bnScale_;
const AccDataType* bnBias_; const BiasDataType* bnBias_;
InOutDataType* p_y_; const YElementwiseOp y_elementwise_op_;
YDataType* p_y_;
AccDataType* resultRunningMean_; MeanVarDataType* resultSaveMean_;
AccDataType* resultRunningVariance_; MeanVarDataType* resultSaveInvVariance_;
AccDataType* resultSaveMean_; MeanVarDataType* resultRunningMean_;
AccDataType* resultSaveInvVariance_; MeanVarDataType* resultRunningVariance_;
bool resultSave, resultRunning; bool resultSave, resultRunning;
index_t n, h, w, c; index_t n, h, w, c;
double exponentialAverageFactor_; AccDataType averageFactor_;
double epsilon_; AccDataType epsilon_;
}; };
struct Invoker : public device::BaseInvoker struct Invoker : public device::BaseInvoker
...@@ -86,14 +105,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -86,14 +105,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
auto thread_reduce_func = [&](auto iC) { auto thread_reduce_func = [&](auto iC) {
AccDataType reduceSize = type_convert<AccDataType>(arg.n) * index_t offset_C = iC;
type_convert<AccDataType>(arg.h) * AccDataType mean = type_convert<AccDataType>(0.0f);
type_convert<AccDataType>(arg.w); AccDataType variance = type_convert<AccDataType>(0.0f);
index_t offset_C = iC; int32_t curr_count = 0;
AccDataType mean = type_convert<AccDataType>(0.0f);
AccDataType meansquare = type_convert<AccDataType>(0.0f); // compute mean, variance using welford method
// compute mean, meanquare, variance, invVariance
for(index_t iN = 0; iN < arg.n; iN++) for(index_t iN = 0; iN < arg.n; iN++)
{ {
index_t offset_N = iN * arg.h * arg.w * arg.c; index_t offset_N = iN * arg.h * arg.w * arg.c;
...@@ -106,40 +123,46 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -106,40 +123,46 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
auto offset = offset_N + offset_H + offset_W + offset_C; auto offset = offset_N + offset_H + offset_W + offset_C;
curr_count++;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]); AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
mean += x; AccDataType delta = x - mean;
meansquare += x * x;
mean += delta / curr_count;
AccDataType delta2 = x - mean;
variance += delta * delta2;
}; };
} }
}; };
mean = mean / reduceSize; // actual variance
meansquare = meansquare / reduceSize; variance = variance / curr_count;
AccDataType variance = meansquare - mean * mean;
AccDataType invVariance = AccDataType invVariance =
type_convert<AccDataType>(1.0f) / type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
// save the mean/invVariance if required // save the mean/invVariance if required
if(arg.resultSave) if(arg.resultSave)
{ {
arg.resultSaveMean_[iC] = mean; arg.resultSaveMean_[iC] = type_convert<MeanVarDataType>(mean);
arg.resultSaveInvVariance_[iC] = invVariance; arg.resultSaveInvVariance_[iC] = type_convert<MeanVarDataType>(invVariance);
}; };
// update the moving average if required // update the moving average if required
if(arg.resultRunning) if(arg.resultRunning)
{ {
arg.resultRunningMean_[iC] = AccDataType oneMinusAverageFactor =
arg.resultRunningMean_[iC] * type_convert<AccDataType>(1.0) - arg.averageFactor_;
type_convert<AccDataType>(1.0 - arg.exponentialAverageFactor_) + arg.resultRunningMean_[iC] = type_convert<MeanVarDataType>(
mean * arg.exponentialAverageFactor_; type_convert<AccDataType>(arg.resultRunningMean_[iC]) *
arg.resultRunningVariance_[iC] = oneMinusAverageFactor +
arg.resultRunningVariance_[iC] * mean * arg.averageFactor_);
type_convert<AccDataType>(1.0 - arg.exponentialAverageFactor_) + arg.resultRunningVariance_[iC] = type_convert<MeanVarDataType>(
variance * arg.exponentialAverageFactor_; arg.resultRunningVariance_[iC] * oneMinusAverageFactor +
variance * arg.averageFactor_);
}; };
// Normalization // Normalization
...@@ -160,7 +183,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -160,7 +183,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType norm_x = AccDataType norm_x =
arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC]; arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
arg.p_y_[offset] = type_convert<InOutDataType>(norm_x); arg.p_y_[offset] = type_convert<YDataType>(norm_x);
}; };
} }
}; };
...@@ -207,34 +230,42 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -207,34 +230,42 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
MakeArgumentPointer(const std::array<index_t, 4> xyLengths, MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides, const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides, const std::array<index_t, 4> yStrides,
const std::array<int, 3> reduceDims,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths, const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleBiasMeanVarStrides, const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* bnScale, const void* bnScale,
const void* bnBias, const void* bnBias,
void* p_y,
double exponentialAverageFactor,
void* resultRunningMean,
void* resultRunningVariance,
double epsilon, double epsilon,
const YElementwiseOp y_elementwise_op,
void* p_y,
void* resultSaveMean, void* resultSaveMean,
void* resultSaveInvVariance) override void* resultSaveInvVariance,
double averageFactor,
void* resultRunningMean,
void* resultRunningVariance) override
{ {
return std::make_unique<Argument>(xyLengths, return std::make_unique<Argument>(xyLengths,
xStrides, xStrides,
yStrides, yStrides,
reduceDims,
bnScaleBiasMeanVarLengths, bnScaleBiasMeanVarLengths,
bnScaleBiasMeanVarStrides, bnScaleStrides,
static_cast<const InOutDataType*>(p_x), bnBiasStrides,
static_cast<const AccDataType*>(bnScale), bnMeanVarStrides,
static_cast<const AccDataType*>(bnBias), static_cast<const XDataType*>(p_x),
static_cast<InOutDataType*>(p_y), static_cast<const ScaleDataType*>(bnScale),
exponentialAverageFactor, static_cast<const BiasDataType*>(bnBias),
static_cast<AccDataType*>(resultRunningMean),
static_cast<AccDataType*>(resultRunningVariance),
epsilon, epsilon,
static_cast<AccDataType*>(resultSaveMean), y_elementwise_op,
static_cast<AccDataType*>(resultSaveInvVariance)); static_cast<YDataType*>(p_y),
static_cast<MeanVarDataType*>(resultSaveMean),
static_cast<MeanVarDataType*>(resultSaveInvVariance),
averageFactor,
static_cast<MeanVarDataType*>(resultRunningMean),
static_cast<MeanVarDataType*>(resultRunningVariance));
}; };
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
......
...@@ -14,7 +14,12 @@ namespace ck { ...@@ -14,7 +14,12 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
template <typename InOutDataType, typename AccDataType> template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3> struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3>
{ {
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
...@@ -23,14 +28,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -23,14 +28,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
const std::array<index_t, 4> xStrides, const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides, const std::array<index_t, 4> yStrides,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths, const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleBiasMeanVarStrides, const std::array<index_t, 1> bnScaleStrides,
const InOutDataType* p_x, const std::array<index_t, 1> bnBiasStrides,
const AccDataType* bnScale, const std::array<index_t, 1> bnMeanVarStrides,
const AccDataType* bnBias, const XDataType* p_x,
const ScaleDataType* bnScale,
const BiasDataType* bnBias,
double epsilon, double epsilon,
const AccDataType* estimatedMean, const MeanVarDataType* estimatedMean,
const AccDataType* estimatedVariance, const MeanVarDataType* estimatedVariance,
InOutDataType* p_y) YDataType* p_y)
: p_x_(p_x), : p_x_(p_x),
bnScale_(bnScale), bnScale_(bnScale),
bnBias_(bnBias), bnBias_(bnBias),
...@@ -39,32 +46,34 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -39,32 +46,34 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
estimatedVariance_(estimatedVariance), estimatedVariance_(estimatedVariance),
p_y_(p_y) p_y_(p_y)
{ {
(void)xStrides; ignore = xStrides;
(void)yStrides; ignore = yStrides;
(void)bnScaleBiasMeanVarStrides; ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 || if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
bnScaleBiasMeanVarLengths[0] != xyLengths[3]) bnScaleBiasMeanVarLengths[0] != xyLengths[3])
throw std::runtime_error("Invalid tensor dimensions!"); throw std::runtime_error("Invalid tensor dimensions!");
n = xyLengths[0]; n_ = xyLengths[0];
h = xyLengths[1]; h_ = xyLengths[1];
w = xyLengths[2]; w_ = xyLengths[2];
c = xyLengths[3]; c_ = xyLengths[3];
} }
const InOutDataType* p_x_; const XDataType* p_x_;
const AccDataType* bnScale_; const ScaleDataType* bnScale_;
const AccDataType* bnBias_; const BiasDataType* bnBias_;
double epsilon_; double epsilon_;
const AccDataType* estimatedMean_; const MeanVarDataType* estimatedMean_;
const AccDataType* estimatedVariance_; const MeanVarDataType* estimatedVariance_;
InOutDataType* p_y_; YDataType* p_y_;
index_t n, h, w, c; index_t n_, h_, w_, c_;
}; };
struct Invoker : public device::BaseInvoker struct Invoker : public device::BaseInvoker
...@@ -81,15 +90,15 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -81,15 +90,15 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance); std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
// Normalization // Normalization
for(index_t iN = 0; iN < arg.n; iN++) for(index_t iN = 0; iN < arg.n_; iN++)
{ {
index_t offset_N = iN * arg.h * arg.w * arg.c; index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h; iH++) for(index_t iH = 0; iH < arg.h_; iH++)
{ {
index_t offset_H = iH * arg.w * arg.c; index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w; iW++) for(index_t iW = 0; iW < arg.w_; iW++)
{ {
index_t offset_W = iW * arg.c; index_t offset_W = iW * arg.c_;
auto offset = offset_N + offset_H + offset_W + offset_C; auto offset = offset_N + offset_H + offset_W + offset_C;
...@@ -98,21 +107,21 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -98,21 +107,21 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
AccDataType norm_x = AccDataType norm_x =
arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC]; arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
arg.p_y_[offset] = type_convert<InOutDataType>(norm_x); arg.p_y_[offset] = type_convert<YDataType>(norm_x);
}; };
} }
}; };
}; };
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread = (arg.c + num_thread - 1) / num_thread; std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread); std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it) for(std::size_t it = 0; it < num_thread; ++it)
{ {
std::size_t ic_begin = it * work_per_thread; std::size_t ic_begin = it * work_per_thread;
std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c); std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c_);
auto f = [=] { auto f = [=] {
for(std::size_t ic = ic_begin; ic < ic_end; ++ic) for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
...@@ -146,7 +155,9 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -146,7 +155,9 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
const std::array<index_t, 4> xStrides, const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides, const std::array<index_t, 4> yStrides,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths, const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleBiasMeanVarStrides, const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* bnScale, const void* bnScale,
const void* bnBias, const void* bnBias,
...@@ -159,14 +170,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -159,14 +170,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
xStrides, xStrides,
yStrides, yStrides,
bnScaleBiasMeanVarLengths, bnScaleBiasMeanVarLengths,
bnScaleBiasMeanVarStrides, bnScaleStrides,
static_cast<const InOutDataType*>(p_x), bnBiasStrides,
static_cast<const AccDataType*>(bnScale), bnMeanVarStrides,
static_cast<const AccDataType*>(bnBias), static_cast<const XDataType*>(p_x),
static_cast<const ScaleDataType*>(bnScale),
static_cast<const BiasDataType*>(bnBias),
epsilon, epsilon,
static_cast<const AccDataType*>(estimatedMean), static_cast<const MeanVarDataType*>(estimatedMean),
static_cast<const AccDataType*>(estimatedVariance), static_cast<const MeanVarDataType*>(estimatedVariance),
static_cast<InOutDataType*>(p_y)); static_cast<YDataType*>(p_y));
}; };
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment