Commit 68886f7d authored by raman jana

merging with latest develop branch

parents a9ee2960 1677cf70
@@ -3,6 +3,7 @@
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
 #include "thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -223,12 +224,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                   const CGridDesc_M_N& c_grid_desc_m_n,
-                  index_t M01,
-                  index_t N01)
+                  const Block2CTileMap& block_2_ctile_map)
     {
         // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
         //                   is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
@@ -256,31 +257,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             return false;
         }

-        // check M01, N01
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        if(!(M0 % M01 == 0 && N0 % N01 == 0))
+        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+        {
             return false;
+        }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return true;
     }

-    __host__ __device__ static constexpr index_t
-    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
-
-        return grid_size;
-    }
-
     __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
         const index_t num_loop = K / KPerBlock;
@@ -315,39 +300,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
     }

     // return block_id to C matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
+    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
+        const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
     {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        const auto M00 = M0 / M01;
-        const auto N00 = N0 / N01;
-
-        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                           make_unmerge_transform(make_tuple(N00, N01))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-
-        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-                make_tuple(Sequence<0, 1, 2, 3>{}),
-                make_tuple(Sequence<0>{}));
-
-        const auto cblockid_to_m0_n0_block_cluster_adaptor =
-            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-
-        return cblockid_to_m0_n0_block_cluster_adaptor;
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
         remove_cvref_t<decltype(
@@ -357,7 +314,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainK0BlockLoop, typename Block2CTileMap>
     __device__ static void
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
@@ -385,6 +342,17 @@
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                       .GetLength(I0),
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                       .GetLength(I3))))
+        {
+            return;
+        }
+
         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
...
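For context, a minimal host-side sketch of how the reworked interface fits together. This is a hedged example, not part of the commit: `GridwiseGemm` and the descriptor variables are placeholders, and it assumes the `BlockToCTileMap_M00_N0_M01Adapt` adapter now also provides `CalculateGridSize`, since the gridwise GEMM's own version is removed above.

    // hypothetical host-side setup (names are placeholders)
    const auto block_2_ctile_map =
        GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);

    if(!GridwiseGemm::CheckValidity(
           a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map))
    {
        throw std::runtime_error("wrong! unsupported problem for this gridwise GEMM");
    }

    // the grid size is now derived from the map, not from the gridwise GEMM
    const index_t grid_size = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n);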
@@ -5,6 +5,7 @@
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
 #include "thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "thread_group_tensor_slice_transfer_v6r2.hpp"
@@ -230,12 +231,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                   const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                   const CGridDesc_M_N& c_grid_desc_m_n,
-                  index_t M01,
-                  index_t N01)
+                  const Block2CTileMap& block_2_ctile_map)
     {
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                       "wrong! K1 need to be known at compile-time");
@@ -264,31 +265,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
             return false;
         }

-        // check M01, N01
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        if(!(M0 % M01 == 0 && N0 % N01 == 0))
+        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+        {
             return false;
+        }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return true;
     }

-    __host__ __device__ static constexpr index_t
-    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
-
-        return grid_size;
-    }
-
     __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
         const index_t num_loop = K / (K0PerBlock * K1);
@@ -324,40 +309,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     }

     // return block_id to C matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
+    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
+        const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
     {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        const auto M00 = M0 / M01;
-        const auto N00 = N0 / N01;
-
-        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                           make_unmerge_transform(make_tuple(N00, N01))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-
-        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-                make_tuple(Sequence<0, 1, 2, 3>{}),
-                make_tuple(Sequence<0>{}));
-
-        const auto cblockid_to_m0_n0_block_cluster_adaptor =
-            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-
-        return cblockid_to_m0_n0_block_cluster_adaptor;
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
         remove_cvref_t<decltype(
             MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
@@ -408,6 +366,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                       .GetLength(I0),
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                       .GetLength(I3))))
+        {
+            return;
+        }
+
         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
@@ -495,7 +464,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         // sanity check
         auto blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                 FloatAB,
                                                                 FloatAcc,
                                                                 decltype(a_block_desc_k0_m_k1),
...
@@ -3,6 +3,7 @@
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
 #include "thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "thread_group_tensor_slice_transfer_v6r3.hpp"
@@ -237,12 +238,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                   const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                   const CGridDesc_M_N& c_grid_desc_m_n,
-                  index_t M01,
-                  index_t N01)
+                  const Block2CTileMap& block_2_ctile_map)
     {
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                       "wrong! K1 need to be known at compile-time");
@@ -271,31 +272,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
             return false;
         }

-        // check M01, N01
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        if(!(M0 % M01 == 0 && N0 % N01 == 0))
+        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+        {
             return false;
+        }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return true;
     }

-    __host__ __device__ static constexpr index_t
-    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
-
-        return grid_size;
-    }
-
     __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
         const index_t num_loop = K / (K0PerBlock * K1);
@@ -331,39 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     }

     // return block_id to C matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
+    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
+        const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
     {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        const auto M00 = M0 / M01;
-        const auto N00 = N0 / N01;
-
-        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                           make_unmerge_transform(make_tuple(N00, N01))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-
-        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-                make_tuple(Sequence<0, 1, 2, 3>{}),
-                make_tuple(Sequence<0>{}));
-
-        const auto cblockid_to_m0_n0_block_cluster_adaptor =
-            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-
-        return cblockid_to_m0_n0_block_cluster_adaptor;
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
         remove_cvref_t<decltype(
@@ -383,7 +340,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

-    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
@@ -427,6 +384,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                       .GetLength(I0),
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                       .GetLength(I3))))
+        {
+            return;
+        }
+
         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
@@ -512,7 +480,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         // sanity check
         auto blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                 FloatAB,
                                                                 FloatAcc,
                                                                 decltype(a_block_desc_k0_m_k1),
...
-#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP
-#define CK_THREADWISE_CONTRACTION_DLOPS_HPP
+#pragma once

 #include "common_header.hpp"
 #include "math.hpp"
@@ -25,9 +23,9 @@ template <typename FloatA,
                   BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
                   CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
               bool>::type = false>
-struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
+struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1
 {
-    __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
+    __device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1()
     {
         static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                       BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
@@ -124,9 +122,9 @@ template <typename FloatA,
                   BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
                   CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
               bool>::type = false>
-struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
+struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
 {
-    __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
+    __device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
     {
         static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                       BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
@@ -220,4 +218,3 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
 };

 } // namespace ck
-#endif
-#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
-#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
+#pragma once

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
@@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1
 };

 } // namespace ck
-#endif
@@ -25,6 +25,7 @@ enum struct MfmaInstr
     mfma_f32_16x16x8bf16,
     mfma_i32_32x32x8i8,
     mfma_i32_16x16x16i8,
+    mfma_f64_16x16x4f64
 };

 template <MfmaInstr instr>
@@ -383,12 +384,40 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
+{
+    static constexpr index_t group_size          = 1;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 4; // group_size * num_groups_per_blk
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4; // wave_size / num_threads_per_blk
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 1;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f64_16x16x4f64<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
 struct MfmaSelector
 {
     template <typename base_type_, index_t MPerXdlops_, index_t NPerXdlops_>
     static constexpr auto GetMfma();

+    template <>
+    static constexpr auto GetMfma<double, 16, 16>()
+    {
+        return MfmaInstr::mfma_f64_16x16x4f64;
+    }
+
     template <>
     static constexpr auto GetMfma<float, 64, 64>()
     {
@@ -661,9 +690,10 @@ struct XdlopsGemm
     template <class FloatA, class FloatB, class FloatC>
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
-        static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
-                          is_same<base_type, bhalf_t>::value || is_same<base_type, int8_t>::value,
-                      "base base_type must be float, half, bfloat16, and int8_t!");
+        static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
+                          is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
+                          is_same<base_type, int8_t>::value,
+                      "base_type must be double, float, half, bfloat16, or int8_t!");

         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
             mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
...
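The constants in the new fp64 mfma_type entry are mutually consistent; spelled out as a quick check (derivation only, no new code):

    // wave_size / num_threads_per_blk = 64 / 16 = 4 input blocks; with
    // is_k_reduction each input block contributes k_per_blk = 1, i.e. K = 4,
    // matching the "16x16x4" in mfma_f64_16x16x4f64. Each lane keeps
    // group_size * num_groups_per_blk = 1 * 4 = 4 accumulators, and
    // 64 lanes * 4 accumulators = 256 = 16 (m_per_blk) * 16 (n_per_blk):
    // one full 16x16 f64 tile per wave.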
@@ -258,6 +258,14 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
     index_t soffset,
     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");

+// buffer atomic-max fp64
+__device__ double
+llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
+                                       int32x4_t rsrc, // dst_wave_buffer_resource
+                                       int voffset,    // dst_thread_addr_offset
+                                       int soffset,    // dst_wave_addr_offset
+                                       int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
+
 template <typename T, index_t N>
 __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
                                                                  index_t src_thread_addr_offset,
@@ -915,6 +923,71 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ
     }
 }

+template <typename T, index_t N>
+__device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::type src_thread_data,
+                                           int32x4_t dst_wave_buffer_resource,
+                                           index_t dst_thread_addr_offset,
+                                           index_t dst_wave_addr_offset)
+{
+    static_assert((is_same<T, double>::value && (N == 1 || N == 2 || N == 4)),
+                  "wrong! not implemented");
+
+    if constexpr(is_same<T, double>::value)
+    {
+        if constexpr(N == 1)
+        {
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data,
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset,
+                                                   0);
+        }
+        else if constexpr(N == 2)
+        {
+            vector_type<double, 2> tmp{src_thread_data};
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset,
+                                                   0);
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset + sizeof(double),
+                                                   0);
+        }
+        else if constexpr(N == 4)
+        {
+            vector_type<double, 4> tmp{src_thread_data};
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset,
+                                                   0);
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset + sizeof(double),
+                                                   0);
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<2>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset + 2 * sizeof(double),
+                                                   0);
+
+            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<3>{}],
+                                                   dst_wave_buffer_resource,
+                                                   dst_thread_addr_offset,
+                                                   dst_wave_addr_offset + 3 * sizeof(double),
+                                                   0);
+        }
+    }
+}
+
 // buffer_load requires:
 // 1) p_src_wave must point to global memory space
 // 2) p_src_wave must be a wavewise pointer.
@@ -1046,4 +1119,39 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
 #endif
 }

+// buffer_atomic_max requires:
+// 1) p_dst_wave must point to global memory
+// 2) p_dst_wave must be a wavewise pointer.
+// It is user's responsibility to make sure that is true.
+template <typename T, index_t N>
+__device__ void
+amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thread_data,
+                      T* p_dst_wave,
+                      const index_t dst_thread_element_offset,
+                      const bool dst_thread_element_valid,
+                      const index_t dst_element_space_size)
+{
+    const int32x4_t dst_wave_buffer_resource =
+        make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
+
+    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
+
+    using vector_t = typename vector_type_maker<T, N>::type::type;
+    using scalar_t = typename scalar_type<vector_t>::type;
+
+    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
+
+#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
+    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
+
+    amd_buffer_atomic_max_impl<scalar_t, vector_size>(
+        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+#else
+    if(dst_thread_element_valid)
+    {
+        amd_buffer_atomic_max_impl<scalar_t, vector_size>(
+            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+    }
+#endif
+}
+
 } // namespace ck
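A hedged caller sketch for the new entry point, following the three requirements in the comment above; all names below are placeholders, not code from this commit:

    __device__ void update_wave_max(double* p_dst_wave, // wavewise pointer into global memory
                                    double local_max,
                                    index_t thread_offset,
                                    index_t buffer_element_count)
    {
        // one fp64 atomic-max per lane; the validity flag guards the tail
        amd_buffer_atomic_max<double, 1>(local_max,
                                         p_dst_wave,
                                         thread_offset,
                                         thread_offset < buffer_element_count,
                                         buffer_element_count);
    }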
@@ -294,5 +294,24 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f64_16x16x4f64;
+
+template <>
+struct intrin_mfma_f64_16x16x4f64<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
+    {
+#ifdef __gfx90a__
+        reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
+            reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
 } // namespace ck
 #endif
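For illustration only, a self-contained gfx90a smoke-test kernel around the same builtin. This is an assumption-laden sketch (plain HIP, and the builtin's (double, double, double4, cbsz, abid, blgp) signature), not code from this commit:

    #include <hip/hip_runtime.h>

    typedef double mfma_double4 __attribute__((ext_vector_type(4)));

    __global__ void mfma_f64_16x16x4_smoke(const double* a, const double* b, mfma_double4* c)
    {
    #ifdef __gfx90a__
        mfma_double4 acc = {0.0, 0.0, 0.0, 0.0};
        // one wave accumulates a 16x16 f64 tile with K = 4 per issue
        acc = __builtin_amdgcn_mfma_f64_16x16x4f64(a[threadIdx.x], b[threadIdx.x], acc, 0, 0, 0);
        c[threadIdx.x] = acc;
    #else
        (void)a;
        (void)b;
        (void)c;
    #endif
    }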
@@ -32,7 +32,7 @@
 #include "debug.hpp"
 #include "amd_buffer_addressing.hpp"
-#include "generic_memory_space_atomic_add.hpp"
+#include "generic_memory_space_atomic.hpp"
 #include "get_id.hpp"
 #include "synchronization.hpp"
 #include "amd_address_space.hpp"
...
@@ -3,7 +3,7 @@
 #include "enable_if.hpp"
 #include "c_style_pointer_cast.hpp"
 #include "amd_buffer_addressing.hpp"
-#include "generic_memory_space_atomic_add.hpp"
+#include "generic_memory_space_atomic.hpp"

 namespace ck {
@@ -125,6 +125,10 @@ struct DynamicBuffer
         {
             this->template AtomicAdd<X>(i, is_valid_element, x);
         }
+        else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax)
+        {
+            this->template AtomicMax<X>(i, is_valid_element, x);
+        }
         else if constexpr(Op == InMemoryDataOperationEnum::Add)
         {
             auto tmp = this->template Get<X>(i, is_valid_element);
@@ -326,6 +330,42 @@ struct DynamicBuffer
         }
     }

+    template <typename X,
+              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
+                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
+                                 bool>::type = false>
+    __host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x)
+    {
+        // X contains multiple T
+        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
+        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
+
+        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
+                      "wrong! X should contain multiple T");
+
+        static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
+
+#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
+        using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
+        bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, double>;
+#else
+        bool constexpr use_amd_buffer_addressing = false;
+#endif
+
+        if constexpr(use_amd_buffer_addressing)
+        {
+            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+
+            amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
+                x, p_data_, i, is_valid_element, element_space_size_);
+        }
+        else if(is_valid_element)
+        {
+            atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
+        }
+    }
+
     __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

     __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
...
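With the dispatch above in place, a store configured for AtomicMax routes through the new member. A minimal sketch, assuming Update keeps the (index, validity, value) shape shown in this hunk; the buffer and operands are placeholders:

    // dst_buf: a global-memory DynamicBuffer over double (construction elided)
    dst_buf.template Update<InMemoryDataOperationEnum::AtomicMax, double>(
        element_offset, is_valid_element, partial_max);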
@@ -3,6 +3,10 @@
 namespace ck {

+// Caution: DO NOT REMOVE
+// intentionally have only declaration but no definition to cause compilation failure when trying to
+// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
+// each datatype.
 template <typename X>
 __device__ X atomic_add(X* p_dst, const X& x);
@@ -24,6 +28,12 @@ __device__ float atomic_add<float>(float* p_dst, const float& x)
     return atomicAdd(p_dst, x);
 }

+template <>
+__device__ double atomic_add<double>(double* p_dst, const double& x)
+{
+    return atomicAdd(p_dst, x);
+}
+
 template <>
 __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
 {
@@ -41,4 +51,70 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
     return vy.template AsType<float2_t>()[I0];
 }

+template <>
+__device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+
+    const vector_type<double, 2> vx{x};
+    vector_type<double, 2> vy{0};
+
+    vy.template AsType<double>()(I0) =
+        atomicAdd(c_style_pointer_cast<double*>(p_dst), vx.template AsType<double>()[I0]);
+    vy.template AsType<double>()(I1) =
+        atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, vx.template AsType<double>()[I1]);
+
+    return vy.template AsType<double2_t>()[I0];
+}
+
+// Caution: DO NOT REMOVE
+// intentionally have only declaration but no definition to cause compilation failure when trying to
+// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
+// each datatype.
+template <typename X>
+__device__ X atomic_max(X* p_dst, const X& x);
+
+template <>
+__device__ int32_t atomic_max<int32_t>(int32_t* p_dst, const int32_t& x)
+{
+    return atomicMax(p_dst, x);
+}
+
+template <>
+__device__ uint32_t atomic_max<uint32_t>(uint32_t* p_dst, const uint32_t& x)
+{
+    return atomicMax(p_dst, x);
+}
+
+template <>
+__device__ float atomic_max<float>(float* p_dst, const float& x)
+{
+    return atomicMax(p_dst, x);
+}
+
+template <>
+__device__ double atomic_max<double>(double* p_dst, const double& x)
+{
+    return atomicMax(p_dst, x);
+}
+
+template <>
+__device__ float2_t atomic_max<float2_t>(float2_t* p_dst, const float2_t& x)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+
+    const vector_type<float, 2> vx{x};
+    vector_type<float, 2> vy{0};
+
+    vy.template AsType<float>()(I0) =
+        atomicMax(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
+    vy.template AsType<float>()(I1) =
+        atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);
+
+    return vy.template AsType<float2_t>()[I0];
+}
+
 } // namespace ck
@@ -11,10 +11,14 @@ __host__ __device__ constexpr index_t get_warp_size()
 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

+__device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
+
 __device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); }

 __device__ index_t get_block_1d_id() { return blockIdx.x; }

 __device__ index_t get_grid_size() { return gridDim.x; }

+__device__ index_t get_block_size() { return blockDim.x; }
+
 } // namespace ck
-#ifndef CK_INNER_PRODUCT_HPP
-#define CK_INNER_PRODUCT_HPP
+#pragma once

 #include "data_type.hpp"

 namespace ck {
@@ -138,7 +136,7 @@ template <>
 __device__ void
 inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
 {
-#if defined(CK_USE_DOT4_I32_I8)
+#if defined(CK_USE_AMD_V_DOT4_I32_I8)
 #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
     asm volatile("\n \
             v_dot4_i32_i8 %0, %1, %2, %0\n \
@@ -202,4 +200,3 @@ inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t
 }

 } // namespace ck
-#endif
@@ -3,11 +3,13 @@
 #include <cmath>

 #include "data_type.hpp"
-#include "half.hpp"
+#include "type.hpp"

 namespace ck {
 namespace math {

+// math functions for the host, some are implemented by calling C++ std functions
+
 static inline __host__ float abs(float x) { return std::abs(x); };

 static inline __host__ double abs(double x) { return std::abs(x); };
@@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x)

 static inline __host__ half_t abs(half_t x)
 {
-    half_float::half xx     = *reinterpret_cast<half_float::half*>(&x);
-    half_float::half abs_xx = half_float::abs(xx);
-    half_t abs_x            = *reinterpret_cast<half_t*>(&abs_xx);
+    uint16_t xx     = ck::bit_cast<uint16_t>(x);
+    uint16_t abs_xx = xx & 0x7fff;
+    half_t abs_x    = ck::bit_cast<half_t>(abs_xx);

     return abs_x;
 };

-static inline __host__ float isnan(float x) { return std::isnan(x); };
+static inline __host__ bool isnan(float x) { return std::isnan(x); };

-static inline __host__ double isnan(double x) { return std::isnan(x); };
+static inline __host__ bool isnan(double x) { return std::isnan(x); };

-static inline __host__ int8_t isnan(int8_t x)
+static inline __host__ bool isnan(int8_t x)
 {
     (void)x;
     return false;
 };

-static inline __host__ int32_t isnan(int32_t x)
+static inline __host__ bool isnan(int32_t x)
 {
     (void)x;
     return false;
@@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x)

 static inline __host__ bool isnan(half_t x)
 {
-    half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
-
-    return half_float::isnan(xx);
+    uint16_t xx = ck::bit_cast<uint16_t>(x);
+
+    return (xx & 0x7FFF) > 0x7C00;
 };

+static inline __host__ float sqrt(float x) { return std::sqrt(x); };
+
+static inline __host__ double sqrt(double x) { return std::sqrt(x); };
+
+// math functions for the HIP kernel, some are implemented by calling hip builtin functions
+static inline __device__ float abs(float x) { return ::abs(x); };
+
+static inline __device__ double abs(double x) { return ::abs(x); };
+
+static inline __device__ int8_t abs(int8_t x)
+{
+    int8_t sgn = x >> (8 - 1);
+
+    return (x ^ sgn) - sgn;
+};
+
+static inline __device__ int32_t abs(int32_t x)
+{
+    int32_t sgn = x >> (32 - 1);
+
+    return (x ^ sgn) - sgn;
+};
+
+static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
+
+static inline __device__ bool isnan(float x) { return ::isnan(x); };
+
+static inline __device__ bool isnan(double x) { return ::isnan(x); };
+
+static inline __device__ bool isnan(int8_t x)
+{
+    (void)x;
+    return false;
+};
+
+static inline __device__ bool isnan(int32_t x)
+{
+    (void)x;
+    return false;
+};
+
+static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
+
+static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
+
+static inline __device__ double sqrt(double x) { return ::sqrt(x); };
+
 } // namespace math
 } // namespace ck
...
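The binary16 NaN test used in the new host-side isnan is worth unpacking once (worked check, no new code):

    // IEEE binary16 layout: sign(1) | exponent(5) | mantissa(10). 0x7C00 is
    // exponent all-ones with zero mantissa (infinity); NaN is exponent
    // all-ones with a nonzero mantissa, hence strictly greater once the sign
    // bit is cleared:
    //   quiet NaN 0x7E00: 0x7E00 & 0x7FFF = 0x7E00 > 0x7C00 -> true
    //   -inf      0xFC00: 0xFC00 & 0x7FFF = 0x7C00          -> false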
@@ -27,6 +27,7 @@
 #define CK_REDUCTION_FUNCTIONS_BINOP_HPP

 #include "data_type.hpp"
+#include "math_v2.hpp"
 #include "reduction_common.hpp"
 #include "reduction_operator.hpp"
@@ -34,18 +35,6 @@
 namespace ck {
 namespace detail {

-template <typename T>
-static inline __device__ bool is_nan(T x)
-{
-    return (isnan(x));
-};
-
-template <>
-inline __device__ bool is_nan<half_t>(half_t x)
-{
-    return (__hisnan(x));
-};
-
 template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck;
@@ -53,7 +42,7 @@ template <typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
 {
     // cppcheck-suppress constParameter
-    __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
         ReduceOperation{}(accuVal, currVal);
     };
@@ -62,9 +51,11 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
 template <typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
 {
-    __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
-        if(is_nan(currVal))
+        using ck::math::isnan;
+
+        if(isnan(currVal))
         {
             accuVal = currVal;
         }
@@ -81,7 +72,7 @@ struct AccumulateWithIndexAndNanCheck;

 template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
 struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
 {
-    __device__ static inline void
+    __host__ __device__ static inline void
     // cppcheck-suppress constParameter
     Calculate(AccDataType& accuVal,
               AccDataType currVal,
@@ -101,12 +92,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType
 struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
 {
     // The method is called when the ReduceOperation is indexable and the user asked for indices
-    __device__ static inline void Calculate(AccDataType& accuVal,
-                                            AccDataType currVal,
-                                            IndexDataType& accuIndex,
-                                            IndexDataType currIndex)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal,
+                                                     AccDataType currVal,
+                                                     IndexDataType& accuIndex,
+                                                     IndexDataType currIndex)
     {
-        if(is_nan(currVal))
+        using ck::math::isnan;
+
+        if(isnan(currVal))
         {
             accuVal   = currVal;
             accuIndex = currIndex;
...
@@ -26,7 +26,8 @@
 #ifndef CK_REDUCTION_OPERATOR_HPP
 #define CK_REDUCTION_OPERATOR_HPP

-#include "common_header.hpp"
+#include "config.hpp"
+#include "data_type.hpp"

 namespace ck {
@@ -35,18 +36,16 @@ namespace reduce {
 // Every binary operator used in reduction is represented by a templated functor class. Each functor
 // class must provide at least
 // three members:
-// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary
+// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
 //    operator, "identity element" is the unique
 //    element in the algebraic space that doesn't affect the value of other elements
 //    when operated against them, and the concept is similar to zero vector in
 //    vector space
 //    (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
-// 2) indexable -- boolean value indicating whether indices of the operated elements could be
-//    recorded. Usually, Min/Max operator could
-//    need to record the indices of elements. For operator like Add/Mul, no need to
-//    record the indices.
-// 3) operator() -- the first argument of the operator must be both an input & output, and the
-//    corresponding variable usually stores
+// 2) IsCompatibleInMemoryDataOperation() -- returns true if the reduction task corresponding to
+//    this operator can use the given InMemoryDataOperation to finalize, or else returns false
+// 3) operator() -- the first argument of the operator must be both an input & output, and the
+//    corresponding variable usually stores
 //    the accumulated result of many operator() calls; the second argument is only an
 //    input. For indexable binary
 //    operator, the second version of operator() has third argument (which is an
@@ -60,7 +59,14 @@ struct Add
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
+
+    __device__ static constexpr bool
+    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
+    {
+        return operation == InMemoryDataOperationEnum::AtomicAdd ||
+               operation == InMemoryDataOperationEnum::Set;
+    };

     __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
 };
@@ -70,7 +76,13 @@ struct Mul
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };
+
+    __device__ static constexpr bool
+    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
+    {
+        return operation == InMemoryDataOperationEnum::Set;
+    };

     __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
 };
@@ -80,11 +92,18 @@ struct Max
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal()
+    __host__ __device__ static constexpr T GetIdentityValue()
     {
         return NumericLimits<T>::Lowest();
     };

+    __device__ static constexpr bool
+    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
+    {
+        // ToChange: atomic_max to be added
+        return operation == InMemoryDataOperationEnum::Set;
+    };
+
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
         if(a < b)
@@ -106,9 +125,13 @@ struct Min
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal()
+    __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };
+
+    __device__ static constexpr bool
+    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
     {
-        return NumericLimits<T>::Max();
+        // ToChange: atomic_min to be added
+        return operation == InMemoryDataOperationEnum::Set;
     };

     __host__ __device__ inline constexpr void operator()(T& a, T b) const
@@ -132,7 +155,14 @@ struct AMax
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
+
+    __device__ static constexpr bool
+    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
+    {
+        // ToChange: atomic_max to be added
+        return operation == InMemoryDataOperationEnum::Set;
+    };

     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
@@ -150,6 +180,17 @@ struct AMax
     }
 };

+template <typename T>
+T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
+{
+    T result = ck::type_convert<T>(0.0f);
+
+    if(operation == InMemoryDataOperationEnum::AtomicMax)
+        result = ck::NumericLimits<T>::Lowest();
+
+    return (result);
+};
+
 }; // end of namespace reduce
 } // end of namespace ck
...
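To make the three-member contract concrete, here is a hypothetical operator written against it. This is a sketch, not part of this commit; SumOfSquares and its choices are invented for illustration:

    template <class T>
    struct SumOfSquares
    {
        using dataType = T;

        // identity: adding 0 leaves the accumulator unchanged
        __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };

        __device__ static constexpr bool
        IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
        {
            // partial sums may be combined with atomic adds or a plain store
            return operation == InMemoryDataOperationEnum::AtomicAdd ||
                   operation == InMemoryDataOperationEnum::Set;
        };

        // first argument accumulates across calls; second is input-only
        __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b * b; }
    };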
@@ -36,6 +36,11 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
     {
         return base::operator()(i);
     }
+
+    __host__ __device__ void Clear()
+    {
+        static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; });
+    }
 };

 // static buffer for vector
@@ -146,9 +151,9 @@ struct StaticBufferTupleOfVector

     __host__ __device__ void Clear()
     {
-        const index_t numScalars = NumOfVector * ScalarPerVector;
+        constexpr index_t NumScalars = NumOfVector * ScalarPerVector;

-        static_for<0, Number<numScalars>{}, 1>{}([&](auto i) { SetAsType(i, S{0}); });
+        static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); });
     }
 };
...
@@ -93,6 +93,13 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
     return r;
 }

+// MultiIndex = MultiIndex * index_t
+template <typename... Xs>
+__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, index_t a)
+{
+    return a * x;
+}
+
 template <typename... Xs>
 __host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
 {
...
@@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv<T>::type;
 template <typename T>
 using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;

+template <typename T>
+using remove_pointer_t = typename std::remove_pointer<T>::type;
+
 template <typename T>
 inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
...
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "stream_config.hpp"
#include "config.hpp"
#include "device_base.hpp"
struct DeviceConvFwdPtr_t
{
using BaseArgument = ck::tensor_operation::device::BaseArgument;
using BaseInvoker = ck::tensor_operation::device::BaseInvoker;
struct DeviceConvFwdPtrImpl;
std::unique_ptr<DeviceConvFwdPtrImpl> pImpl;
DeviceConvFwdPtr_t();
~DeviceConvFwdPtr_t();
DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&);
DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&);
DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete;
DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete;
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* in_ptr,
void* wei_ptr,
void* out_ptr,
size_t N,
size_t K,
size_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
        const; // in, wei and out element ops are ignored for now, since even if we changed them,
               // they can't be linked
std::unique_ptr<BaseInvoker>
MakeInvokerPointer() const; // requires including BaseInvoker headers
std::string GetTypeString();
bool IsSupportedArgument(const BaseArgument* arg_ptr);
};
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
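
A hedged end-to-end sketch of how these declarations are meant to be consumed; the device pointers and problem shape below are made-up placeholders, and the commented invoker call assumes the BaseInvoker interface from device_base.hpp:

    #include <vector>

    void run_first_supported_conv(void* p_in, void* p_wei, void* p_out)
    {
        // hypothetical NHWC/KYXC/NHWK problem: 71x71 input, 3x3 filter, stride 2, pad 1
        size_t N = 128, K = 256, C = 192;
        std::vector<ck::index_t> in_spatial{71, 71}, filt_spatial{3, 3}, out_spatial{36, 36};
        std::vector<ck::index_t> strides{2, 2}, dilations{1, 1}, left_pads{1, 1}, right_pads{1, 1};

        std::vector<DeviceConvFwdPtr_t> instances;
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(instances);

        for(auto& conv : instances)
        {
            auto argument = conv.MakeArgumentPointer(p_in, p_wei, p_out, N, K, C,
                                                     in_spatial, filt_spatial, out_spatial,
                                                     strides, dilations, left_pads, right_pads);

            if(!conv.IsSupportedArgument(argument.get()))
                continue;

            auto invoker = conv.MakeInvokerPointer();
            // invoker->Run(argument.get(), StreamConfig{}); // assumed BaseInvoker API
            break;
        }
    }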