Commit 6e3cf8b0 authored by Jing Zhang

merge develop

parents 4ad62d7f ba58a93f
@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
@@ -185,12 +186,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
              const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
              const CGridDesc_M_N& c_grid_desc_m_n,
-             index_t M01,
-             index_t N01)
+             const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
              "wrong! K1 need to be known at compile-time");
@@ -219,31 +220,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return false;
}
-// check M01, N01
-constexpr auto M1 = Number<MPerBlock>{};
-constexpr auto N1 = Number<NPerBlock>{};
-const auto M0 = M / M1;
-const auto N0 = N / N1;
-if(!(M0 % M01 == 0 && N0 % N01 == 0))
-    return false;
+if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+{
+    return false;
+}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
-__host__ __device__ static constexpr index_t
-CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
-{
-    const auto M = c_grid_desc_m_n.GetLength(I0);
-    const auto N = c_grid_desc_m_n.GetLength(I1);
-    const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
-    return grid_size;
-}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
@@ -305,36 +290,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
-    const auto M = c_grid_desc_m_n.GetLength(I0);
-    const auto N = c_grid_desc_m_n.GetLength(I1);
-    constexpr auto M1 = Number<MPerBlock>{};
-    constexpr auto N1 = Number<NPerBlock>{};
-    const auto M0 = M / M1;
-    const auto N0 = N / N1;
-    const auto M00 = M0 / M01;
-    const auto N00 = N0 / N01;
-    const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                       make_unmerge_transform(make_tuple(N00, N01))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-    const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-            make_tuple(Sequence<0, 1, 2, 3>{}),
-            make_tuple(Sequence<0>{}));
-    const auto cblockid_to_m0_n0_block_cluster_adaptor =
-        chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                              cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-    return cblockid_to_m0_n0_block_cluster_adaptor;
+    return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
+        c_grid_desc_m_n, M01, N01);
}
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
@@ -368,6 +325,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const auto block_work_idx =
    block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
    __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
...
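The same refactor pattern repeats for every gridwise GEMM below: the raw (M01, N01) parameters are replaced by a map object created via MakeDefaultBlock2CTileMap, validity is delegated to that object, and the kernel early-exits on out-of-range tiles. A minimal host-side sketch of the intended call sequence follows; the wrapper function itself and the assumption that the map also exposes CalculateGridSize are mine, not part of this commit.

#include <stdexcept>

// Sketch only: GridwiseGemm stands for any of the gridwise GEMM structs touched in this
// commit, and the descriptors are assumed to be constructed by the usual device-op code.
template <typename GridwiseGemm, typename ADesc, typename BDesc, typename CDesc>
ck::index_t prepare_gemm_launch(const ADesc& a_grid_desc_k0_m_k1,
                                const BDesc& b_grid_desc_k0_n_k1,
                                const CDesc& c_grid_desc_m_n,
                                ck::index_t M01,
                                ck::index_t N01)
{
    // the map object now owns the block-id -> (m0, n0) tile decomposition
    const auto block_2_ctile_map =
        GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01);

    // CheckValidity() takes the map instead of M01/N01 and defers the tile-count check
    // to block_2_ctile_map.CheckValidity(c_grid_desc_m_n)
    if(!GridwiseGemm::CheckValidity(
           a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, block_2_ctile_map))
    {
        throw std::runtime_error("wrong! invalid GEMM argument");
    }

    // grid size is assumed to come from the map as well, replacing the removed
    // GridwiseGemm::CalculateGridSize()
    return block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n);
}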
@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
@@ -167,12 +168,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
              const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
              const CMNGridDesc& c_m_n_grid_desc,
-             index_t M01,
-             index_t N01)
+             const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
              "wrong! K1 need to be known at compile-time");
@@ -196,31 +197,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
    return false;
-// check M01, N01
-constexpr auto M1 = Number<MPerBlock>{};
-constexpr auto N1 = Number<NPerBlock>{};
-const auto M0 = M / M1;
-const auto N0 = N / N1;
-if(!(M0 % M01 == 0 && N0 % N01 == 0))
-    return false;
+if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
+{
+    return false;
+}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
-__host__ __device__ static constexpr index_t
-CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
-{
-    const auto M = c_m_n_grid_desc.GetLength(I0);
-    const auto N = c_m_n_grid_desc.GetLength(I1);
-    const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;
-    return grid_size;
-}
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = K0 > K0PerBlock;
@@ -282,37 +267,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
    const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
{
-    const auto M = c_m_n_grid_desc.GetLength(I0);
-    const auto N = c_m_n_grid_desc.GetLength(I1);
-    constexpr auto M1 = Number<MPerBlock>{};
-    constexpr auto N1 = Number<NPerBlock>{};
-    const auto M0 = M / M1;
-    const auto N0 = N / N1;
-    const auto M00 = M0 / M01;
-    const auto N00 = N0 / N01;
-    const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_pass_through_transform(KBatch),
-                       make_unmerge_transform(make_tuple(M00, M01)),
-                       make_unmerge_transform(make_tuple(N00, N01))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
-    const auto cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
-            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-            make_tuple(Sequence<0>{}));
-    const auto cblockid_to_kbatch_m0_n0_block_cluster_adaptor =
-        chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                              cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
-    return cblockid_to_kbatch_m0_n0_block_cluster_adaptor;
+    return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CMNGridDesc>(
+        c_m_n_grid_desc, M01, N01, KBatch);
}
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
@@ -344,6 +300,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
const auto block_work_idx =
    c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!c_block_cluster_adaptor.ValidCTileIndex(
make_tuple(block_work_idx[I1], block_work_idx[I2]),
make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1))))
{
return;
}
const index_t k_batch_id = block_work_idx[I0];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
...
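The k-split map folds the extra KBatch dimension into the same one-dimensional block id; the grid-size bookkeeping it takes over from the removed CalculateGridSize is just the tile-count product. A small sketch using the formula deleted above:

// One workgroup per (kbatch, m-tile, n-tile) triple: exactly what the removed
// CalculateGridSize(c_m_n_grid_desc, KBatch) returned and what
// BlockToCTileMap_KSplit_M00_N00_M01_N01 is assumed to reproduce internally.
ck::index_t ksplit_grid_size(
    ck::index_t M, ck::index_t N, ck::index_t MPerBlock, ck::index_t NPerBlock, ck::index_t KBatch)
{
    return (M / MPerBlock) * (N / NPerBlock) * KBatch;
}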
@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -174,12 +175,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
              const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
              const CMNGridDesc& c_m_n_grid_desc,
-             index_t M01,
-             index_t N01)
+             const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
              "wrong! K1 need to be known at compile-time");
@@ -203,31 +204,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
    return false;
-// check M01, N01
-constexpr auto M1 = Number<MPerBlock>{};
-constexpr auto N1 = Number<NPerBlock>{};
-const auto M0 = M / M1;
-const auto N0 = N / N1;
-if(!(M0 % M01 == 0 && N0 % N01 == 0))
-    return false;
+if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
+{
+    return false;
+}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
-__host__ __device__ static constexpr index_t
-CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
-{
-    const auto M = c_m_n_grid_desc.GetLength(I0);
-    const auto N = c_m_n_grid_desc.GetLength(I1);
-    const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;
-    return grid_size;
-}
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = K0 > K0PerBlock;
@@ -256,37 +241,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
    const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
{
-    const auto M = c_m_n_grid_desc.GetLength(I0);
-    const auto N = c_m_n_grid_desc.GetLength(I1);
-    constexpr auto M1 = Number<MPerBlock>{};
-    constexpr auto N1 = Number<NPerBlock>{};
-    const auto M0 = M / M1;
-    const auto N0 = N / N1;
-    const auto M00 = M0 / M01;
-    const auto N00 = N0 / N01;
-    const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_pass_through_transform(KBatch),
-                       make_unmerge_transform(make_tuple(M00, M01)),
-                       make_unmerge_transform(make_tuple(N00, N01))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
-    const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
-            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-            make_tuple(Sequence<0>{}));
-    const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor =
-        chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                              c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
-    return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor;
+    return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CMNGridDesc>(
+        c_m_n_grid_desc, M01, N01, KBatch);
}
__host__ __device__ static constexpr auto
@@ -333,6 +289,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const auto block_work_idx =
    c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!c_block_cluster_adaptor.ValidCTileIndex(
make_tuple(block_work_idx[I1], block_work_idx[I2]),
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t k_batch_id = block_work_idx[I0];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
...
@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -223,12 +224,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
              const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
              const CGridDesc_M_N& c_grid_desc_m_n,
-             index_t M01,
-             index_t N01)
+             const Block2CTileMap& block_2_ctile_map)
{
// static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
//                   is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
@@ -256,31 +257,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
return false;
}
-// check M01, N01
-constexpr auto M1 = Number<MPerBlock>{};
-constexpr auto N1 = Number<NPerBlock>{};
-const auto M0 = M / M1;
-const auto N0 = N / N1;
-if(!(M0 % M01 == 0 && N0 % N01 == 0))
-    return false;
+if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+{
+    return false;
+}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
-__host__ __device__ static constexpr index_t
-CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
-{
-    const auto M = c_grid_desc_m_n.GetLength(I0);
-    const auto N = c_grid_desc_m_n.GetLength(I1);
-    const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
-    return grid_size;
-}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
@@ -318,36 +303,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
-    const auto M = c_grid_desc_m_n.GetLength(I0);
-    const auto N = c_grid_desc_m_n.GetLength(I1);
-    constexpr auto M1 = Number<MPerBlock>{};
-    constexpr auto N1 = Number<NPerBlock>{};
-    const auto M0 = M / M1;
-    const auto N0 = N / N1;
-    const auto M00 = M0 / M01;
-    const auto N00 = N0 / N01;
-    const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                       make_unmerge_transform(make_tuple(N00, N01))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-    const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-            make_tuple(Sequence<0, 1, 2, 3>{}),
-            make_tuple(Sequence<0>{}));
-    const auto cblockid_to_m0_n0_block_cluster_adaptor =
-        chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                              cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-    return cblockid_to_m0_n0_block_cluster_adaptor;
+    return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
+        c_grid_desc_m_n, M01, N01);
}
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
    remove_cvref_t<decltype(
@@ -385,6 +342,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
const auto block_work_idx =
    block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I0),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I3))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
    __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
...
@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r2.hpp"
@@ -230,12 +231,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
              const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
              const CGridDesc_M_N& c_grid_desc_m_n,
-             index_t M01,
-             index_t N01)
+             const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
              "wrong! K1 need to be known at compile-time");
@@ -264,31 +265,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
return false;
}
-// check M01, N01
-constexpr auto M1 = Number<MPerBlock>{};
-constexpr auto N1 = Number<NPerBlock>{};
-const auto M0 = M / M1;
-const auto N0 = N / N1;
-if(!(M0 % M01 == 0 && N0 % N01 == 0))
-    return false;
+if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+{
+    return false;
+}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
-__host__ __device__ static constexpr index_t
-CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
-{
-    const auto M = c_grid_desc_m_n.GetLength(I0);
-    const auto N = c_grid_desc_m_n.GetLength(I1);
-    const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
-    return grid_size;
-}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
@@ -327,37 +312,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
-    const auto M = c_grid_desc_m_n.GetLength(I0);
-    const auto N = c_grid_desc_m_n.GetLength(I1);
-    constexpr auto M1 = Number<MPerBlock>{};
-    constexpr auto N1 = Number<NPerBlock>{};
-    const auto M0 = M / M1;
-    const auto N0 = N / N1;
-    const auto M00 = M0 / M01;
-    const auto N00 = N0 / N01;
-    const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                       make_unmerge_transform(make_tuple(N00, N01))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-    const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-            make_tuple(Sequence<0, 1, 2, 3>{}),
-            make_tuple(Sequence<0>{}));
-    const auto cblockid_to_m0_n0_block_cluster_adaptor =
-        chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                              cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-    return cblockid_to_m0_n0_block_cluster_adaptor;
+    return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
+        c_grid_desc_m_n, M01, N01);
}
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
    remove_cvref_t<decltype(
        MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
@@ -408,6 +366,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
const auto block_work_idx =
    block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I0),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I3))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
    __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
...
@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
+#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r3.hpp"
@@ -237,12 +238,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
              const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
              const CGridDesc_M_N& c_grid_desc_m_n,
-             index_t M01,
-             index_t N01)
+             const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
              "wrong! K1 need to be known at compile-time");
@@ -271,31 +272,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
return false;
}
-// check M01, N01
-constexpr auto M1 = Number<MPerBlock>{};
-constexpr auto N1 = Number<NPerBlock>{};
-const auto M0 = M / M1;
-const auto N0 = N / N1;
-if(!(M0 % M01 == 0 && N0 % N01 == 0))
-    return false;
+if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+{
+    return false;
+}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
-__host__ __device__ static constexpr index_t
-CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
-{
-    const auto M = c_grid_desc_m_n.GetLength(I0);
-    const auto N = c_grid_desc_m_n.GetLength(I1);
-    const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
-    return grid_size;
-}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
@@ -334,36 +319,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
-    const auto M = c_grid_desc_m_n.GetLength(I0);
-    const auto N = c_grid_desc_m_n.GetLength(I1);
-    constexpr auto M1 = Number<MPerBlock>{};
-    constexpr auto N1 = Number<NPerBlock>{};
-    const auto M0 = M / M1;
-    const auto N0 = N / N1;
-    const auto M00 = M0 / M01;
-    const auto N00 = N0 / N01;
-    const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
-                       make_unmerge_transform(make_tuple(N00, N01))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-    const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
-            make_tuple(Sequence<0, 1, 2, 3>{}),
-            make_tuple(Sequence<0>{}));
-    const auto cblockid_to_m0_n0_block_cluster_adaptor =
-        chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                              cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-    return cblockid_to_m0_n0_block_cluster_adaptor;
+    return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
+        c_grid_desc_m_n, M01, N01);
}
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
    remove_cvref_t<decltype(
@@ -427,6 +384,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
const auto block_work_idx =
    block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I0),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I3))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
    __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
...
@@ -258,6 +258,14 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
// buffer atomic-max fp64
__device__ double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int32x4_t rsrc, // dst_wave_buffer_resource
int voffset, // dst_thread_addr_offset
int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
template <typename T, index_t N>
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
                                                                  index_t src_thread_addr_offset,
@@ -915,6 +923,71 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ
}
}
template <typename T, index_t N>
__device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert((is_same<T, double>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
if constexpr(is_same<T, double>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
vector_type<double, 2> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(double),
0);
}
else if constexpr(N == 4)
{
vector_type<double, 4> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(double),
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 2 * sizeof(double),
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 3 * sizeof(double),
0);
}
}
}
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
@@ -1046,4 +1119,39 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
#endif
}
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void
amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
amd_buffer_atomic_max_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
amd_buffer_atomic_max_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
} // namespace ck
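A short device-side sketch of how the new wrapper is meant to be called; the argument roles mirror amd_buffer_atomic_add, and everything outside the single call is illustrative:

// Sketch: per-thread fp64 atomic max into a wavewise global buffer. The caller is
// responsible for p_dst_wave being a global-memory, wavewise pointer, as documented above.
__device__ void running_max_update(double* p_dst_wave,
                                   ck::index_t dst_element_space_size,
                                   ck::index_t dst_thread_element_offset,
                                   bool dst_thread_element_valid,
                                   double v)
{
    ck::amd_buffer_atomic_max<double, 1>(v,
                                         p_dst_wave,
                                         dst_thread_element_offset,
                                         dst_thread_element_valid,
                                         dst_element_space_size);
}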
@@ -32,7 +32,7 @@
#include "debug.hpp"
#include "amd_buffer_addressing.hpp"
-#include "generic_memory_space_atomic_add.hpp"
+#include "generic_memory_space_atomic.hpp"
#include "get_id.hpp"
#include "synchronization.hpp"
#include "amd_address_space.hpp"
...
@@ -3,7 +3,7 @@
#include "enable_if.hpp"
#include "c_style_pointer_cast.hpp"
#include "amd_buffer_addressing.hpp"
-#include "generic_memory_space_atomic_add.hpp"
+#include "generic_memory_space_atomic.hpp"
namespace ck {
@@ -125,6 +125,10 @@ struct DynamicBuffer
{
this->template AtomicAdd<X>(i, is_valid_element, x);
}
else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax)
{
this->template AtomicMax<X>(i, is_valid_element, x);
}
else if constexpr(Op == InMemoryDataOperationEnum::Add)
{
auto tmp = this->template Get<X>(i, is_valid_element);
@@ -326,6 +330,42 @@ struct DynamicBuffer
}
}
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
__host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, double>;
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
if constexpr(use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
}
else if(is_valid_element)
{
atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
}
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
...
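Stores issued with InMemoryDataOperationEnum::AtomicMax are routed to the new AtomicMax member by Update(); a hedged sketch of a direct call follows. The make_dynamic_buffer helper is assumed to exist with its usual (pointer, element-space-size) signature, and only global memory is supported per the static_assert above.

// Sketch: fp64 atomic-max store through DynamicBuffer. Whether the fast buffer-addressing
// path or the generic atomic_max fallback is taken depends on
// CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64.
__device__ void store_max(double* p_global, ck::index_t element_space_size, ck::index_t i, double v)
{
    auto buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(p_global, element_space_size);
    buf.template AtomicMax<double>(i, true, v);
}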
@@ -3,6 +3,10 @@
namespace ck {
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
// each datatype.
template <typename X>
__device__ X atomic_add(X* p_dst, const X& x);
@@ -41,4 +45,53 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
return vy.template AsType<float2_t>()[I0];
}
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
// each datatype.
template <typename X>
__device__ X atomic_max(X* p_dst, const X& x);
template <>
__device__ int32_t atomic_max<int32_t>(int32_t* p_dst, const int32_t& x)
{
return atomicMax(p_dst, x);
}
template <>
__device__ uint32_t atomic_max<uint32_t>(uint32_t* p_dst, const uint32_t& x)
{
return atomicMax(p_dst, x);
}
template <>
__device__ float atomic_max<float>(float* p_dst, const float& x)
{
return atomicMax(p_dst, x);
}
template <>
__device__ double atomic_max<double>(double* p_dst, const double& x)
{
return atomicMax(p_dst, x);
}
template <>
__device__ float2_t atomic_max<float2_t>(float2_t* p_dst, const float2_t& x)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
const vector_type<float, 2> vx{x};
vector_type<float, 2> vy{0};
vy.template AsType<float>()(I0) =
atomicMax(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
vy.template AsType<float>()(I1) =
atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);
return vy.template AsType<float2_t>()[I0];
}
} // namespace ck
@@ -11,10 +11,14 @@ __host__ __device__ constexpr index_t get_warp_size()
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
+__device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
__device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); }
__device__ index_t get_block_1d_id() { return blockIdx.x; }
__device__ index_t get_grid_size() { return gridDim.x; }
+__device__ index_t get_block_size() { return blockDim.x; }
} // namespace ck
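The two new helpers round out the id/size queries; for instance, a grid-stride loop can now be written entirely in terms of them (illustrative kernel, not part of the commit):

// Sketch: grid-stride traversal using get_thread_global_1d_id(), get_block_size(),
// and get_grid_size().
__global__ void scale_inplace(float* p, ck::index_t n, float alpha)
{
    for(ck::index_t i = ck::get_thread_global_1d_id(); i < n;
        i += ck::get_block_size() * ck::get_grid_size())
    {
        p[i] *= alpha;
    }
}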
@@ -93,6 +93,13 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
return r;
}
// MultiIndex = MultiIndex * index_t
template <typename... Xs>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, index_t a)
{
return a * x;
}
template <typename... Xs>
__host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
{
...
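The added overload makes multi-index scaling commutative by forwarding to the existing index_t-on-the-left operator; a tiny illustration:

// Sketch: both spellings now produce the same scaled multi-index.
__host__ __device__ void multi_index_scaling_example()
{
    const auto idx = ck::make_multi_index(2, 3);
    const auto a   = 4 * idx; // pre-existing overload: index_t * MultiIndex
    const auto b   = idx * 4; // new overload, forwards to a * x
    (void)a;
    (void)b;
}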
@@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv<T>::type;
template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
+template <typename T>
+using remove_pointer_t = typename std::remove_pointer<T>::type;
template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
...
#pragma once
#include <memory>
#include <string>
#include "stream_config.hpp"
#include "config.hpp"
#include "device_base.hpp"
struct DeviceConvFwdPtr_t
{
using BaseArgument = ck::tensor_operation::device::BaseArgument;
using BaseInvoker = ck::tensor_operation::device::BaseInvoker;
struct DeviceConvFwdPtrImpl;
std::unique_ptr<DeviceConvFwdPtrImpl> pImpl;
DeviceConvFwdPtr_t();
~DeviceConvFwdPtr_t();
DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&);
DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&);
DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete;
DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete;
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* in_ptr,
void* wei_ptr,
void* out_ptr,
size_t N,
size_t K,
size_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
const; // in, wei and out element ops are ignored for now since, even if we change them, they
// can't be linked
std::unique_ptr<BaseInvoker>
MakeInvokerPointer() const; // requires including BaseInvoker headers
std::string GetTypeString();
bool IsSupportedArgument(const BaseArgument* arg_ptr);
};
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
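A sketch of the intended client-side flow for this opaque handle. The placeholder arguments and the commented-out invoker call are illustrative only; the exact BaseInvoker::Run signature is not part of this header.

#include <vector>

void run_first_supported_conv2d_fwd(void* in_dev, void* wei_dev, void* out_dev,
                                    size_t N, size_t K, size_t C,
                                    const std::vector<ck::index_t>& in_spatial,
                                    const std::vector<ck::index_t>& filter_spatial,
                                    const std::vector<ck::index_t>& out_spatial,
                                    const std::vector<ck::index_t>& strides,
                                    const std::vector<ck::index_t>& dilations,
                                    const std::vector<ck::index_t>& left_pads,
                                    const std::vector<ck::index_t>& right_pads)
{
    std::vector<DeviceConvFwdPtr_t> instances;
    add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(instances);

    for(auto& op : instances)
    {
        auto arg = op.MakeArgumentPointer(in_dev, wei_dev, out_dev, N, K, C,
                                          in_spatial, filter_spatial, out_spatial,
                                          strides, dilations, left_pads, right_pads);
        if(op.IsSupportedArgument(arg.get()))
        {
            auto invoker = op.MakeInvokerPointer();
            // invoker->Run(arg.get(), ...); // Run signature assumed, not shown in this header
            return;
        }
    }
}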
-#ifndef DEVICE_HPP
-#define DEVICE_HPP
+#pragma once
#include <memory>
#include <functional>
#include <thread>
#include <chrono>
-#include "hip/hip_runtime.h"
-#include "hip/hip_fp16.h"
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
#include "stream_config.hpp"
#include "ck/options.hpp"
template <typename T>
__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
{
for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x)
{
p[i] = x;
}
}
inline void hip_check_error(hipError_t x)
{
if(x != hipSuccess)
{
std::ostringstream ss;
ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__
<< "in function: " << __func__;
throw std::runtime_error(ss.str());
}
}
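Typical use of the new helper is to wrap every HIP runtime call so failures surface as exceptions with file/line context (illustrative snippet):

// Sketch: checked allocation + zero-fill using hip_check_error.
inline void* checked_device_alloc(std::size_t nbytes)
{
    void* p = nullptr;
    hip_check_error(hipMalloc(&p, nbytes));
    hip_check_error(hipMemset(p, 0, nbytes));
    return p;
}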
struct DeviceMem
{
@@ -17,6 +39,16 @@ struct DeviceMem
void ToDevice(const void* p);
void FromDevice(void* p);
void SetZero();
template <typename T>
void SetValue(T x)
{
if(mMemSize % sizeof(T) != 0)
{
throw std::runtime_error("wrong! not entire DeviceMem will be set");
}
set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
}
~DeviceMem();
void* mpDeviceBuf;
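SetValue is handy for priming result buffers before atomic reductions; a sketch, assuming the byte-size constructor of DeviceMem is unchanged by this commit:

#include <limits>

// Sketch: fill an fp64 result buffer with -inf before a kernel that accumulates with
// InMemoryDataOperationEnum::AtomicMax.
inline void prime_max_result_buffer(std::size_t element_count)
{
    DeviceMem result(sizeof(double) * element_count);
    result.SetValue(-std::numeric_limits<double>::infinity());
    // ... hand result.mpDeviceBuf to the kernel, then copy back with result.FromDevice(...) ...
}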
@@ -36,49 +68,56 @@ struct KernelTimer
std::unique_ptr<KernelTimerImpl> impl;
};
+using device_stream_t = hipStream_t;
template <typename... Args, typename F>
-void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
-{
-    hipStream_t stream_id = nullptr;
-    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
-}
-template <typename... Args, typename F>
-float launch_and_time_kernel(
-    F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
-{
-    KernelTimer timer;
-    printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
-           __func__,
-           grid_dim.x,
-           grid_dim.y,
-           grid_dim.z,
-           block_dim.x,
-           block_dim.y,
-           block_dim.z);
-    printf("Warm up\n");
-    hipStream_t stream_id = nullptr;
-    // warm up
-    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
-    printf("Start running %d times...\n", nrepeat);
-    timer.Start();
-    for(int i = 0; i < nrepeat; ++i)
-    {
-        hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
-    }
-    timer.End();
-    return timer.GetElapsedTime() / nrepeat;
-}
+float launch_and_time_kernel(const StreamConfig& stream_config,
+                             F kernel,
+                             dim3 grid_dim,
+                             dim3 block_dim,
+                             std::size_t lds_byte,
+                             Args... args)
+{
+#if CK_TIME_KERNEL
+    if(stream_config.time_kernel_)
+    {
+        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
+               __func__,
+               grid_dim.x,
+               grid_dim.y,
+               grid_dim.z,
+               block_dim.x,
+               block_dim.y,
+               block_dim.z);
+        const int nrepeat = 10;
+        printf("Warm up 1 time\n");
+        // warm up
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+        printf("Start running %d times...\n", nrepeat);
+        KernelTimer timer;
+        timer.Start();
+        for(int i = 0; i < nrepeat; ++i)
+        {
+            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+        }
+        timer.End();
+        return timer.GetElapsedTime() / nrepeat;
+    }
+    else
+    {
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+        return 0;
+    }
+#else
+    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+    return 0;
+#endif
+}
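A sketch of calling the reworked launcher; StreamConfig is assumed to be an aggregate whose two members used above (stream_id_, time_kernel_) can be brace-initialized:

// Sketch: time a kernel on the default stream, or just launch it when timing is off.
template <typename Kernel, typename... Args>
float run_and_maybe_time(Kernel kernel, dim3 grid_dim, dim3 block_dim, bool time_kernel, Args... args)
{
    const StreamConfig stream_config{nullptr, time_kernel};
    return launch_and_time_kernel(stream_config, kernel, grid_dim, block_dim, 0, args...);
}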
@@ -84,7 +84,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
return 0;
}
-float Run(const device::BaseArgument* p_arg, int) override
+float Run(const device::BaseArgument* p_arg,
+          const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
...
-#ifndef REFERENCE_CONV_WRW_HPP
-#define REFERENCE_CONV_WRW_HPP
+#pragma once
#include <iostream>
#include <sstream>
@@ -16,7 +15,9 @@ template <typename InDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
-typename OutElementwiseOperation>
+typename OutElementwiseOperation,
+ck::index_t NumDimSpatial = 2,
+typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator
{
// Argument
@@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
-: in_n_c_hi_wi_{in_n_c_hi_wi},
-  wei_k_c_y_x_{wei_k_c_y_x},
-  out_n_k_ho_wo_{out_n_k_ho_wo},
+: input_{in_n_c_hi_wi},
+  weight_{wei_k_c_y_x},
+  output_{out_n_k_ho_wo},
conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
@@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
{
}
-const Tensor<InDataType>& in_n_c_hi_wi_;
-Tensor<WeiDataType>& wei_k_c_y_x_;
-const Tensor<OutDataType>& out_n_k_ho_wo_;
+const Tensor<InDataType>& input_;
+Tensor<WeiDataType>& weight_;
+const Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
@@ -66,62 +67,184 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
float Run(const Argument& arg)
{
-constexpr auto I0 = Number<0>{};
-constexpr auto I1 = Number<1>{};
-auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
-    float v_acc = 0;
-    for(std::size_t n = 0; n < arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]; ++n)
-    {
-        for(std::size_t ho = 0; ho < arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; ++ho)
-        {
-            auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
-                      ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
-                      ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
-            for(std::size_t wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo)
-            {
-                auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
-                          ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I1]) -
-                          ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
-                if(hi >= 0 &&
-                   ck::type_convert<std::size_t>(hi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
-                   wi >= 0 &&
-                   ck::type_convert<std::size_t>(wi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
-                {
-                    float v_out;
-                    float v_in;
-                    arg.out_element_op_(
-                        v_out, ck::type_convert<float>(arg.out_n_k_ho_wo_(n, k, ho, wo)));
-                    arg.in_element_op_(
-                        v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
-                    v_acc += v_out * v_in;
-                }
-            }
-        }
-    }
-    float v_wei;
-    arg.wei_element_op_(v_wei, v_acc);
-    arg.wei_k_c_y_x_(k, c, y, x) = ck::type_convert<OutDataType>(v_wei);
-};
-make_ParallelTensorFunctor(f_kcyx,
-                           arg.wei_k_c_y_x_.mDesc.GetLengths()[0],
-                           arg.wei_k_c_y_x_.mDesc.GetLengths()[1],
-                           arg.wei_k_c_y_x_.mDesc.GetLengths()[2],
-                           arg.wei_k_c_y_x_.mDesc.GetLengths()[3])(
-    std::thread::hardware_concurrency());
-return 0;
+if constexpr(NumDimSpatial == 1)
+{
+    constexpr auto I0 = Number<0>{};
+    auto f_kcx = [&](auto k, auto c, auto x) {
+        float v_acc = 0;
+        for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
+        {
+            for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo)
+            {
+                auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) +
+                          ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) -
+                          ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
+                if(wi >= 0 &&
+                   ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
+                {
+                    float v_out;
+                    float v_in;
+                    arg.out_element_op_(v_out, ck::type_convert<float>(arg.output_(n, k, wo)));
+                    arg.in_element_op_(v_in, ck::type_convert<float>(arg.input_(n, c, wi)));
+                    v_acc += v_out * v_in;
+                }
+            }
+        }
+        float v_wei;
+        arg.wei_element_op_(v_wei, v_acc);
+        arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei);
+    };
+    make_ParallelTensorFunctor(f_kcx,
+                               arg.weight_.mDesc.GetLengths()[0],
+                               arg.weight_.mDesc.GetLengths()[1],
+                               arg.weight_.mDesc.GetLengths()[2])(
+        std::thread::hardware_concurrency());
+    return 0;
+}
else if constexpr(NumDimSpatial == 2)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
{
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[3])
{
float v_out;
float v_in;
arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
v_acc += v_out * v_in;
}
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kcyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
{
for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_)
{
auto di =
ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(y *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4];
++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo *
arg.conv_strides_[I2]) +
ck::type_convert<ck::long_index_t>(
x * arg.conv_dilations_[I2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4])
{
float v_out;
float v_in;
arg.out_element_op_(v_out,
ck::type_convert<float>(
arg.output_(n, k, do_, ho, wo)));
arg.in_element_op_(
v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
v_acc += v_out * v_in;
}
}
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kczyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3],
arg.weight_.mDesc.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
}
-float Run(const device::BaseArgument* p_arg, int) override
+float Run(const device::BaseArgument* p_arg,
+          const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
@@ -181,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
} // namespace host
} // namespace tensor_operation
} // namespace ck
-#endif
@@ -291,7 +291,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
-float Run(const device::BaseArgument* p_arg, int) override
+float Run(const device::BaseArgument* p_arg,
+          const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
...
-#ifndef REFERENCE_CONV_FWD_HPP
-#define REFERENCE_CONV_FWD_HPP
+#pragma once
#include <iostream>
#include <type_traits>
#include <sstream>
+#include "stream_config.hpp"
#include "device_base.hpp"
#include "host_tensor.hpp"
@@ -251,7 +252,8 @@ struct ReferenceConvFwd : public device::BaseOperator
}
}
-float Run(const device::BaseArgument* p_arg, int) override
+float Run(const device::BaseArgument* p_arg,
+          const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
@@ -311,4 +313,3 @@ struct ReferenceConvFwd : public device::BaseOperator
} // namespace host
} // namespace tensor_operation
} // namespace ck
-#endif
@@ -124,7 +124,8 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
return 0;
}
-float Run(const device::BaseArgument* p_arg, int) override
+float Run(const device::BaseArgument* p_arg,
+          const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
...