"git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "40006a6a3623eb6e564604e2672a1ea4e44b4fe7"
Commit 740149fc authored by Chao Liu

clean up

parent 40836ab9
@@ -126,7 +126,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         // blockwise copy
         // input: format is [C, Hi, Wi, N]
         auto blockwise_in_copy =
-            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                                decltype(in_c_h_w_n_global_desc),
                                                decltype(in_c_h_w_n_block_desc),
                                                decltype(in_c_h_w_n_block_desc.GetLengths()),
@@ -142,9 +142,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                                                                   {0, 0, 0, 0});

         // blockwise wei copy
-        // format is [CPerBlock, KPerBlock]
+        // format is [CPerBlock, X * KPerBlock]
         const auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                                decltype(wei_c_k_global_desc),
                                                decltype(wei_c_k_block_desc),
                                                decltype(wei_c_k_block_desc.GetLengths()),
@@ -317,7 +317,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                     wo_block_data_begin + wo_thread_data_begin,
                     n_block_data_begin + n_thread_data_begin);

-                ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc),
+#if 1
+                ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+                                                      decltype(out_10d_global_desc),
+                                                      decltype(out_10d_thread_desc.GetLengths()),
+                                                      arithmetic_sequence_gen<0, 10, 1>::type,
+                                                      9,
+                                                      OutThreadCopyDataPerAccess_N,
+                                                      OutThreadCopyDataPerAccess_N>(
+                    make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                    .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
+                ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
                                                       decltype(out_10d_global_desc),
                                                       decltype(out_10d_thread_desc.GetLengths()),
                                                       arithmetic_sequence_gen<0, 10, 1>::type,
@@ -328,6 +339,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                                                       OutThreadCopyDataPerAccess_N>(
                     make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
                     .Run(p_out_thread, p_out_thread_on_global);
+#endif
             }).Else([&](auto fwd) {
                 static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                                   GemmNPerThreadSubC % NPerThread == 0,
@@ -375,7 +387,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                     wo_block_data_begin + wo_thread_data_begin,
                     n_block_data_begin + n_thread_data_begin);

-                ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc),
+#if 1
+                ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+                                                      decltype(out_10d_global_desc),
+                                                      decltype(out_10d_thread_desc.GetLengths()),
+                                                      arithmetic_sequence_gen<0, 10, 1>::type,
+                                                      9,
+                                                      OutThreadCopyDataPerAccess_N,
+                                                      OutThreadCopyDataPerAccess_N>(
+                    make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                    .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
+                ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
                                                       decltype(out_10d_global_desc),
                                                       decltype(out_10d_thread_desc.GetLengths()),
                                                       arithmetic_sequence_gen<0, 10, 1>::type,
@@ -386,6 +409,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                                                       OutThreadCopyDataPerAccess_N>(
                     make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
                     .Run(p_out_thread, p_out_thread_on_global);
+#endif
             });
         }
     };
...
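For readers less familiar with the slice-copy template arguments repeated in these hunks, a minimal sketch of what the dimension-order parameter expands to. This is an assumption about the repo's own `arithmetic_sequence_gen<From, To, Inc>::type`, namely that it enumerates [From, To) in steps of Inc; `Sequence` and `is_same` are the repo's utilities.

    // Hedged sketch: the 10-d output copy visits all ten dimensions in natural order
    // and vectorizes along dimension 9 (the N dimension), moving
    // OutThreadCopyDataPerAccess_N elements per access.
    using OutAccessOrder = arithmetic_sequence_gen<0, 10, 1>::type;
    static_assert(is_same<OutAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>>{}, "wrong");
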
@@ -7,10 +7,6 @@
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_batched_gemm.hpp"
-#include "blockwise_2d_tensor_op.hpp"
-#include "blockwise_4d_tensor_op.hpp"
-#include "threadwise_tensor_slice_copy.hpp"
-#include "threadwise_4d_tensor_op.hpp"

 namespace ck {
@@ -133,18 +129,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
         constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
             Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

-#if 1
         // blockwise copy
         // input: format is [C, Hi, Wi, N]
-        const auto blockwise_in_copy =
-            Blockwise4dTensorCopy3<BlockSize,
-                                   Float,
-                                   decltype(in_c_h_w_n_global_desc),
-                                   decltype(in_c_h_w_n_block_desc),
-                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
-                                   InBlockCopyClusterLengths_CHWN,
-                                   InBlockCopyDataPerAccess_N>{};
-#else
         auto blockwise_in_copy =
             BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                                decltype(in_c_h_w_n_global_desc),
@@ -160,19 +146,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
                                                InBlockCopyDataPerAccess_N,
                                                InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
                                                                            {0, 0, 0, 0});
-#endif

-#if 1
         // blockwise wei copy
         // format is [CPerBlock, X * KPerBlock]
-        const auto blockwise_wei_copy =
-            Blockwise2dTensorCopy3<BlockSize,
-                                   Float,
-                                   decltype(wei_c_k_global_desc),
-                                   decltype(wei_c_k_block_desc),
-                                   decltype(wei_c_k_block_desc.GetLengths()),
-                                   WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
-#else
         const auto blockwise_wei_copy =
             BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                                decltype(wei_c_k_global_desc),
@@ -187,7 +163,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
                                                1,
                                                WeiBlockCopyDataPerAccess_K,
                                                WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
-#endif

         // a series of blockwise batched GEMM
         // C_matrix += transpose(A_matrix) * B_matrix
@@ -428,13 +403,16 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
                     n_block_data_begin + n_thread_data_begin);

 #if 1
-                threadwise_tensor_slice_copy(out_10d_thread_desc,
-                                             p_out_thread,
-                                             out_10d_global_desc,
-                                             p_out_thread_on_global,
-                                             out_10d_thread_desc.GetLengths(),
-                                             Number<OutThreadCopyDataPerAccess_N>{});
-#else
+                ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+                                                      decltype(out_10d_global_desc),
+                                                      decltype(out_10d_thread_desc.GetLengths()),
+                                                      arithmetic_sequence_gen<0, 10, 1>::type,
+                                                      9,
+                                                      OutThreadCopyDataPerAccess_N,
+                                                      OutThreadCopyDataPerAccess_N>(
+                    make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                    .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
                 ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
                                                       decltype(out_10d_global_desc),
                                                       decltype(out_10d_thread_desc.GetLengths()),
@@ -495,13 +473,16 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
                     n_block_data_begin + n_thread_data_begin);

 #if 1
-                threadwise_tensor_slice_copy(out_10d_thread_desc,
-                                             p_out_thread,
-                                             out_10d_global_desc,
-                                             p_out_thread_on_global,
-                                             out_10d_thread_desc.GetLengths(),
-                                             Number<OutThreadCopyDataPerAccess_N>{});
-#else
+                ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+                                                      decltype(out_10d_global_desc),
+                                                      decltype(out_10d_thread_desc.GetLengths()),
+                                                      arithmetic_sequence_gen<0, 10, 1>::type,
+                                                      9,
+                                                      OutThreadCopyDataPerAccess_N,
+                                                      OutThreadCopyDataPerAccess_N>(
+                    make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                    .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
                 ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
                                                       decltype(out_10d_global_desc),
                                                       decltype(out_10d_thread_desc.GetLengths()),
...
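The comment "C_matrix += transpose(A_matrix) * B_matrix" above is the contract that every variant of the blockwise batched GEMM implements. As a reference point only, a scalar sketch of that accumulation; the loop bounds and flat pointer layout here are illustrative, not the kernel's real blocked and vectorized layout.

    // Hedged scalar reference for C += A^T * B, with A stored K x M (transposed)
    // and B stored K x N; p_a_thread / p_b_thread / p_c_thread stand in for the
    // per-thread register tiles.
    for(index_t k = 0; k < K; ++k)
        for(index_t m = 0; m < M; ++m)
            for(index_t n = 0; n < N; ++n)
                p_c_thread[m * N + n] += p_a_thread[k * M + m] * p_b_thread[k * N + n];
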
@@ -234,12 +234,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
         // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
         // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

         constexpr auto b_e_n1bn2_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
-                in_e_n1_b_n2_block_desc.Unfold(I1, I3));
+            make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));

         // sanity check
         static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
...
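For context, a minimal sketch of the Unfold + make_ConstantMatrixDescriptor pattern these hunks converge on. The lengths are made up for illustration, and the sketch assumes Unfold(I1, I3) merges the inclusive dimension range 1..3 of a packed descriptor.

    // Hypothetical shapes, not the kernel's real tile sizes.
    constexpr auto I1 = Number<1>{};
    constexpr auto I3 = Number<3>{};

    // 4-d packed block descriptor [E, N1, B, N2] = [8, 2, 16, 4].
    constexpr auto in_e_n1_b_n2_block_desc =
        make_ConstantTensorDescriptor_packed(Sequence<8, 2, 16, 4>{});

    // Unfold dims 1..3 into one -> 2-d [E, N1 * B * N2] = [8, 128], which the
    // tensor-descriptor overload of make_ConstantMatrixDescriptor accepts directly.
    constexpr auto b_mtx_desc =
        make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
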
@@ -247,12 +247,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
         // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

         constexpr auto b_e_n1bn2_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
-                in_e_n1_b_n2_block_desc.Unfold(I1, I3));
+            make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));

         // sanity check
         static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
...
@@ -222,8 +222,7 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
         // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
         // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

         // this check is ad-hoc
         // TODO: need to properly implement tensor descriptor with multiple alignment
@@ -233,8 +232,7 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
                       "GemmDataPerReadB alignment requirement is not satisfied");

         constexpr auto b_e_n0ho0wo0bn2ho2wo2_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
-                in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
+            make_ConstantMatrixDescriptor(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));

         // sanity check
         static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
...
@@ -228,8 +228,7 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
         // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
         // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

         // this check is ad-hoc
         // TODO: need to properly implement tensor descriptor with multiple alignment
@@ -239,8 +238,7 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
                       "GemmDataPerReadB alignment requirement is not satisfied");

         constexpr auto b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
-                in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
+            make_ConstantMatrixDescriptor(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));

         // sanity check
         static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
...
@@ -172,11 +172,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
         // b_mtx[EPerBlocl, BPerBlock] is in LDS
         // c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

-        constexpr auto b_e_b_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(in_e_b_block_desc);
+        constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);

         // sanity check
         static_assert(
...

@@ -182,11 +182,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // b_mtx[EPerBlocl, BPerBlock] is in LDS
         // c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

-        constexpr auto b_e_b_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(in_e_b_block_desc);
+        constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);

         // sanity check
         static_assert(
...
@@ -52,10 +52,10 @@ __host__ __device__ constexpr auto
     return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
 }

-template <class TDesc>
-__host__ __device__ constexpr auto
-make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(TDesc)
+template <class... Ts>
+__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>)
 {
+    using TDesc = ConstantTensorDescriptor<Ts...>;
     static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
     static_assert(TDesc::GetStrides()[1] == 1, "wrong");
     return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
...
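A minimal usage sketch of the renamed overload, assuming only what its static_asserts require (a 2-D descriptor whose last dimension has unit stride); the sizes below are illustrative, not taken from any kernel in this commit.

    // Hypothetical 2-d packed descriptor [CPerBlock, KPerBlock] = [8, 128];
    // packed layout gives the last dimension a stride of 1, as required.
    constexpr auto wei_c_k_block_desc =
        make_ConstantTensorDescriptor_packed(Sequence<8, 128>{});

    // New spelling: the conversion is now just another make_ConstantMatrixDescriptor
    // overload, so call sites read the same whether they start from lengths or from
    // a tensor descriptor.
    constexpr auto wei_c_k_mtx_desc = make_ConstantMatrixDescriptor(wei_c_k_block_desc);

    static_assert(wei_c_k_mtx_desc.NRow() == 8 && wei_c_k_mtx_desc.NCol() == 128, "wrong");
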
@@ -5,6 +5,10 @@
 #include "ConstantMatrixDescriptor.hpp"
 #include "threadwise_gemm.hpp"

+#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
+#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 1
+#endif
+
 namespace ck {

 template <index_t BlockSize,
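The new guard gives builds a way to opt out of the hand-written GCN inline-asm GEMM path without editing the header. A hedged sketch of the intent; the override below is hypothetical, only the #ifndef/#define pair above is part of this commit.

    // Hypothetical opt-out from a translation unit (or -D on the compile line):
    // a definition that precedes the include wins over the header's default of 1.
    #define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 0
    #include "blockwise_gemm.hpp"
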
@@ -97,24 +101,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
                            b_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
-
-#if 0
-        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-        {
-            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
-            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
-            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
-
-            printf("%u %u, %u %u %u, %u %u\n",
-                   get_block_1d_id(),
-                   get_thread_local_1d_id(),
-                   c_thread_mtx_index.batch,
-                   c_thread_mtx_index.row,
-                   c_thread_mtx_index.col,
-                   mMyThreadOffsetA,
-                   mMyThreadOffsetB);
-        }
-#endif
     }

     __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
@@ -257,29 +243,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                 }
             }

-#if 0
-            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-            {
-                printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n",
-                       p_a_thread[0],
-                       p_a_thread[1],
-                       p_a_thread[2],
-                       p_a_thread[3],
-                       p_a_thread[4],
-                       p_a_thread[5],
-                       p_a_thread[6],
-                       p_a_thread[7],
-                       p_b_thread[0],
-                       p_b_thread[1],
-                       p_b_thread[2],
-                       p_b_thread[3],
-                       p_b_thread[4],
-                       p_b_thread[5],
-                       p_b_thread[6],
-                       p_b_thread[7]);
-            }
-#endif
-
             threadwise_gemm(a_thread_mtx,
                             True,
                             p_a_thread,
@@ -311,10 +274,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         // thread A, B for GEMM
         // A is transposed, b is not
         constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});

         constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});

         // thread A-sub, B-sub for copy
         constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
@@ -382,102 +345,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
         outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
     }
-
-    template <class FloatA, class FloatB, class FloatC>
-    __device__ void Run_asm_v2(const FloatA* __restrict__ p_a_block,
-                               const FloatB* __restrict__ p_b_block,
-                               FloatC* __restrict__ p_c_thread) const
-    {
-        constexpr auto a_block_mtx  = BlockMatrixA{};
-        constexpr auto b_block_mtx  = BlockMatrixB{};
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-        constexpr index_t M = a_block_mtx.NCol();
-        constexpr index_t N = b_block_mtx.NCol();
-        constexpr index_t K = a_block_mtx.NRow(); // A is transposed
-
-        constexpr index_t MPerThread = c_thread_mtx.NRow();
-        constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-        // thread A, B for GEMM
-        // A is transposed, b is not
-        constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
-
-        constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
-
-        // thread A-sub, B-sub for copy
-        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
-
-        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
-
-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
-
-        // assertion for inline asm
-        static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
-                          is_same<FloatC, float>{},
-                      "Run_amd_asm only deal with float\n");
-
-        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
-                          MPerThread == 8 && NPerThread == 8,
-                      "Run_amd_asm cannot deal with this GEMM shape yet\n");
-
-        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only do float4 read\n");
-
-        static_assert(BlockMatrixStrideA == 0 && BatchPerThread == 1,
-                      "Run_amd_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == "
-                      "1 for now\n");
-
-        using Float4 = vector_type<float, 4>::MemoryType;
-
-        Float4* reg_a = (Float4*)(p_a_thread);
-        Float4* reg_b = (Float4*)(p_b_thread);
-        Float4* reg_c = (Float4*)(p_c_thread);
-
-        void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA);
-        void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB);
-
-        constexpr index_t a_lds_row_stride = sizeof(float) * a_block_mtx.RowStride();
-        constexpr index_t b_lds_row_stride = sizeof(float) * b_block_mtx.RowStride();
-
-        constexpr index_t a_lds_cluster_col_stride = sizeof(float) * MPerLevel1Cluster;
-        constexpr index_t b_lds_cluster_col_stride = sizeof(float) * NPerLevel1Cluster;
-
-        ds_read_b128(reg_a[0], a_lds_loc, 0);
-        ds_read_b128(reg_b[0], b_lds_loc, 0);
-        ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride);
-        ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride);
-
-        lgkmcnt(2);
-        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-        lgkmcnt(1);
-        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-
-#pragma unroll
-        for(index_t k = 1; k < K; ++k)
-        {
-            ds_read_b128(reg_a[0], a_lds_loc, k * a_lds_row_stride);
-            lgkmcnt(1);
-            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-            ds_read_b128(reg_b[0], b_lds_loc, k * b_lds_row_stride);
-            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-            ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride + k * b_lds_row_stride);
-            ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride + k * a_lds_row_stride);
-            lgkmcnt(2);
-            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-            lgkmcnt(1);
-            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-        }
-
-        lgkmcnt(0);
-        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-    }
 #endif

     template <class FloatA, class FloatB, class FloatC>
...
@@ -7,10 +7,8 @@
 #define CK_DEVICE_BACKEND_AMD 1
 #define CK_USE_AMD_INLINE_ASM 1
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
-#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
-#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0

 namespace ck {
...

@@ -9,10 +9,8 @@
 #define CK_DEVICE_BACKEND_NVIDIA 1
 #define CK_USE_AMD_INLINE_ASM 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
-#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
-#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0

 namespace ck {
...
@@ -71,7 +71,7 @@ int main(int argc, char* argv[])
 {
     using namespace ck;

-#if 0
+#if 1
     constexpr index_t N = 64;
     constexpr index_t C = 1536;
     constexpr index_t HI = 8;
...