Commit 5aed38d4 authored by Jing Zhang

clean

parent e5f7ded6
@@ -11,129 +11,29 @@
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatC* __restrict__ p_bias_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_bias_grid,
p_c_grid,
p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{});
}
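// Illustration only, not part of this kernel (assumed example values): the __shared__
// array above is sized in elements, i.e. the byte count reported by the gridwise GEMM
// divided by sizeof(FloatAB). For example, 8 KiB of LDS and a 2-byte FloatAB give:
static_assert(8192 / 2 == 4096, "example: 8 KiB of LDS holds 4096 two-byte FloatAB elements");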
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptors by CONSTANT void pointer
// CONSTANT is needed to inform the compiler that the void pointers in the kernel signature
// point to the non-modifiable parameter address space, so it can enable the corresponding
// optimizations
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatC* __restrict__ p_bias_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc,
const void CONSTANT* p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const void CONSTANT* p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const void CONSTANT* p_c_blockid_to_k_n_h_w_block_cluster_adaptor)
{
// first cast the CONSTANT void* to a generic void*,
// then cast the generic void* to the concrete descriptor type;
// the tensor descriptor's copy constructor cannot take an address_space(4) pointer
const auto a_e0_e1_k0_k1_e2_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K0_K1_E2*>(
cast_pointer_to_generic_address_space(p_a_e0_e1_k0_k1_e2_grid_desc));
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
*reinterpret_cast<const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2*>(
cast_pointer_to_generic_address_space(p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc));
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
*reinterpret_cast<const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2*>(
cast_pointer_to_generic_address_space(p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc));
const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToBlockClusterAdaptor_K_N_H_W*>(
cast_pointer_to_generic_address_space(p_c_blockid_to_k_n_h_w_block_cluster_adaptor));
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_bias_grid,
p_c_grid,
p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{});
}
#endif
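// Illustration only, not part of this header: the CONSTANT-pointer variant above follows
// the pattern sketched below. A descriptor arrives as an address_space(4) void pointer,
// is converted to a generic pointer, reinterpreted as its concrete type, and copied by
// value so the rest of the kernel works on an ordinary object. "SomeDesc" is a placeholder.
//
//   const void CONSTANT* p_desc = ...;                      // kernel argument
//   const auto desc = *reinterpret_cast<const SomeDesc*>(   // concrete descriptor type
//       cast_pointer_to_generic_address_space(p_desc));     // strip address_space(4)
//   // from here on, desc behaves exactly like a by-value descriptor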
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGridDesc_E0_E1_K_E2, typename AGlobalDesc,
typename BGridDesc_E0_E1_N_Ho_Wo_E2, typename BGlobalDesc,
typename CGridDesc_K_N_Ho_Wo, typename CGlobalDesc,
index_t E1_,
index_t E2_,
index_t K2_,
index_t KPerBlock, index_t KPerBlock,
index_t HoPerBlock, index_t HoPerBlock,
index_t WoPerBlock, index_t WoPerBlock,
index_t E1PerBlock, index_t EPerBlock,
index_t KPerThread, index_t KPerThread,
index_t HoPerThread, index_t HoPerThread,
index_t WoPerThread, index_t WoPerThread,
index_t EPerThread, index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, typename ABlockTransferThreadClusterLengths_E_K,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_E2, index_t ABlockTransferDstScalarPerVector_K,
bool AThreadTransferSrcResetCoordinateAfterRun, bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcVectorDim,
@@ -146,497 +46,294 @@ template <index_t BlockSize,
typename BGlobalStepHacks, typename BGlobalStepHacks,
typename CGlobalStepHacks, typename CGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks, typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks, typename BGlobalMoveSliceWindowStepHacks>
ActivTypeEnum_t activ_type = ActivTypeEnum_t::None> struct GridwiseGemmDlops_km_kn_mn_v3
struct GridwiseGemmDlops_km_kn_mn_v2
{ {
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto E1 = Number<E1_>{};
static constexpr auto E2 = Number<E2_>{};
static constexpr auto K2 = Number<K2_>{};
static constexpr auto NPerBlock = I1;
static constexpr FloatC alpha = 0.3;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{}; constexpr auto E = EPerBlock * 3 * 3;
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_e0_e1_k1_e2_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<E1>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align); make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = math::integer_least_multiple( constexpr auto a_block_space_size =
a_e0_e1_k1_e2_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(FloatAB); return a_block_space_size * sizeof(FloatAB);
} }
__host__ __device__ static constexpr index_t template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
CalculateGridSize(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) __device__ void Run(const AGlobalDesc& a_e_k_global_desc,
{ const FloatAB* __restrict__ p_a_global,
const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); const BGlobalDesc& b_e_n_ho_wo_global_desc,
const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1);
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K0 = K / KPerBlock;
const auto N0 = N / NPerBlock;
const auto H0 = Ho / HoPerBlock;
const auto W0 = Wo / WoPerBlock;
const index_t grid_size = K0 * N0 * H0 * W0;
return grid_size;
}
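// Illustration only (assumed example shape, not part of the original header): one
// workgroup is launched per output tile, so the grid size is the product of the tile
// counts in K, N, Ho and Wo. E.g. K = 256, N = 1, Ho = Wo = 64 with KPerBlock = 16,
// NPerBlock = 1, HoPerBlock = WoPerBlock = 16 give 16 * 1 * 4 * 4 workgroups:
static_assert((256 / 16) * (1 / 1) * (64 / 16) * (64 / 16) == 256,
              "example grid size: 256 workgroups");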
__host__ __device__ static constexpr bool CalculateHasMainE0BlockLoop(const index_t E0)
{
const bool has_main_e0_block_loop = E0 > 1;
return has_main_e0_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasMainE1BlockLoop()
{
const bool has_main_e1_block_loop = (E1 + E1PerBlock) / (2 * E1PerBlock) > 1;
return has_main_e1_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailE1BlockLoop()
{
const bool has_double_tail_e1_block_loop = (E1 / E1PerBlock) % 2 == 0;
return has_double_tail_e1_block_loop;
}
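// Illustration only (assumed example values, not part of the original header): the main
// E1 loop consumes 2 * E1PerBlock per iteration (double-buffered even/odd steps), and the
// double tail covers the final two E1PerBlock stages. For E1 = 8, E1PerBlock = 2 there are
// 8 / 2 = 4 stages: (8 + 2) / (2 * 2) = 2 > 1, so the main loop runs, and (8 / 2) % 2 == 0,
// so the last two stages are handled by the double tail.
static_assert((8 + 2) / (2 * 2) > 1 && (8 / 2) % 2 == 0,
              "example: E1 = 8, E1PerBlock = 2 -> main loop plus double tail");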
__host__ __device__ static constexpr auto
MakeAE0E1K0K1E2GridDescriptor(const AGridDesc_E0_E1_K_E2& a_e0_e1_k_e2_grid_desc)
{
const auto E0 = a_e0_e1_k_e2_grid_desc.GetLength(I0);
const auto K = a_e0_e1_k_e2_grid_desc.GetLength(I2);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
const auto a_e0_e1_k0_k1_e2_grid_desc = transform_tensor_descriptor(
a_e0_e1_k_e2_grid_desc,
make_tuple(make_pass_through_transform(E0),
make_pass_through_transform(E1),
make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return a_e0_e1_k0_k1_e2_grid_desc;
}
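// Illustration only (assumed example values, not part of the original header): the unmerge
// above splits K into (K0, K1) with K1 = KPerBlock, so e.g. K = 256 and KPerBlock = 16 give
// K0 = 16, and a flat index recombines as k = k0 * K1 + k1.
static_assert(256 / 16 == 16, "example K0 = K / KPerBlock");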
__host__ __device__ static constexpr auto
MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{
const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0);
const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1);
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
const auto H2 = Number<HoPerThread>{};
const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Ho / (H1 * H2);
const auto W2 = Number<WoPerThread>{};
const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Wo / (W1 * W2);
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = transform_tensor_descriptor(
c_k_n_ho_wo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1, H2)),
make_unmerge_transform(make_tuple(W0, W1, W2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc;
}
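// Illustration only (assumed example values, not part of the original header): Ho and Wo
// are each split three ways: H2 = HoPerThread elements per thread, H1 = HoPerBlock /
// HoPerThread threads per workgroup in that dimension, and H0 workgroup tiles. E.g.
// Ho = 32, HoPerBlock = 8, HoPerThread = 2 give H2 = 2, H1 = 4, H0 = 32 / (4 * 2) = 4,
// and ho recombines as ho = h0 * (H1 * H2) + h1 * H2 + h2.
static_assert(32 / ((8 / 2) * 2) == 4, "example H0 = Ho / (H1 * H2)");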
__host__ __device__ static constexpr auto MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(
const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc)
{
const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0);
// const auto E1 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I1);
const auto N = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I2);
const auto Ho = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I3);
const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4);
// const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5);
const auto H2 = Number<HoPerThread>{};
const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Ho / (H1 * H2);
const auto W2 = Number<WoPerThread>{};
const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Wo / (W1 * W2);
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
transform_tensor_descriptor(b_e0_e1_n_ho_wo_e2_grid_desc,
make_tuple(make_pass_through_transform(E0),
make_pass_through_transform(E1),
make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1, H2)),
make_unmerge_transform(make_tuple(W0, W1, W2)),
make_pass_through_transform(E2)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3, 4, 5>{},
Sequence<6, 7, 8>{},
Sequence<9>{}));
return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCBlockIdToKNHoWoBlockClusterAdaptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{
const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0);
const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1);
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K0 = Number<K / KPerBlock>{};
const auto N0 = Number<N / NPerBlock>{};
const auto H0 = Number<Ho / HoPerBlock>{};
const auto W0 = Number<Wo / WoPerBlock>{};
const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_k_n_ho_wo_block_cluster_adaptor;
}
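// Illustration only (assumed example values, not part of the original header): the adaptor
// above merges (K0, N0, H0, W0) into a single block id, so CalculateBottomIndex is the
// usual mixed-radix decomposition. E.g. with (K0, N0, H0, W0) = (4, 1, 8, 8), block id 137
// maps to k0 = 137 / 64 = 2, n0 = 0, h0 = (137 % 64) / 8 = 1, w0 = 137 % 8 = 1.
static_assert(137 / 64 == 2 && (137 % 64) / 8 == 1 && 137 % 8 == 1,
              "example block-id decomposition");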
using AGridDesc_E0_E1_K0_K1_E2 =
decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{}));
__host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor(
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc)
{
const auto K0 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I0);
const auto K1 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I1);
return make_naive_tensor_descriptor_packed(make_tuple(K0, K1));
}
template <bool HasMainE0BlockLoop_>
__device__ static void
Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_bias_global, const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc_, integral_constant<bool, HasMainKBlockLoop>,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_, integral_constant<bool, HasDoubleTailKBlockLoop>) const
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor_,
integral_constant<bool, HasMainE0BlockLoop_>)
{ {
constexpr auto I0 = Number<0>{};
constexpr auto a_e0_e1_k0_k1_e2_grid_desc = AGridDesc_E0_E1_K0_K1_E2{}; constexpr auto I1 = Number<1>{};
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = constexpr auto I2 = Number<2>{};
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2{}; constexpr auto I3 = Number<3>{};
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2{};
constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
CBlockIdToBlockClusterAdaptor_K_N_H_W{};
constexpr auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
constexpr bool HasMainE0BlockLoop =
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0) > 1;
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); p_a_global, a_e_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
constexpr auto E = EPerBlock * 3 * 3;
constexpr auto HasMainE1BlockLoop = CalculateHasMainE1BlockLoop(); // const auto E = a_e_k_global_desc.GetLength(I0);
constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop(); const auto K = a_e_k_global_desc.GetLength(I1);
// const auto Ho = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I3); const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
// const auto Wo = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I4); const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
const auto c_k_n_h_w_block_cluster_idx = // divide block work by [M, N]
c_blockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex( #if 0
make_multi_index(get_block_1d_id())); const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num;
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num;
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#else
// Hack: this forces the result into an SGPR
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id = const index_t k_block_work_id =
__builtin_amdgcn_readfirstlane(c_k_n_h_w_block_cluster_idx[I0]); __builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num);
const index_t n_block_work_id = const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
__builtin_amdgcn_readfirstlane(c_k_n_h_w_block_cluster_idx[I1]);
const index_t ho_block_work_id = const index_t ho_block_work_id =
__builtin_amdgcn_readfirstlane(c_k_n_h_w_block_cluster_idx[I2]); __builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num);
const index_t wo_block_work_id = const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
__builtin_amdgcn_readfirstlane(c_k_n_h_w_block_cluster_idx[I3]); #endif
// lds max alignment
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{}; // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E1PerBlock>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align); make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto b_e1_n_h_w_e2_block_gemm_desc = // B matrix in LDS memory, dst of blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(Number<E1PerBlock>{}, // be careful of LDS alignment
I1, constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<HoPerBlock>{}, Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
Number<WoPerBlock>{},
Number<E2>{}));
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = make_naive_tensor_descriptor_packed( // c_thread_mtx definition: this is a mess
make_tuple(Number<KPerThread>{}, I1, Number<HoPerThread>{}, Number<WoPerThread>{})); // TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize, BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_e1_k1_e2_block_gemm_desc), decltype(a_e_k_block_desc),
decltype(b_e1_n_h_w_e2_block_gemm_desc), decltype(b_e_n_ho_wo_block_desc),
decltype(c_k1_n_h2_w2_thread_gemm_desc), decltype(c_k_n_ho_wo_thread_desc),
KPerThread,
HoPerThread,
WoPerThread,
EPerThread, EPerThread,
K2>{}; ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
auto c_thread_mtx_index = const auto k_thread_id = c_thread_mtx_index.k;
blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); const auto ho_thread_id = c_thread_mtx_index.h;
const auto wo_thread_id = c_thread_mtx_index.w;
const auto k_thread_id = c_thread_mtx_index[I0]; const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
const auto ho_thread_id = c_thread_mtx_index[I2]; const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
const auto wo_thread_id = c_thread_mtx_index[I3]; const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned( const index_t ho_thread_data_on_global =
make_tuple(Number<I1>{}, Number<E1>{}, I1, Number<KPerBlock>{}, Number<E2>{}), ho_block_data_on_global + ho_thread_id * HoPerThread;
max_lds_align); const index_t wo_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread;
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<I1, E1, I1, KPerBlock, E2>, Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_e0_e1_k0_k1_e2_grid_desc), decltype(a_e_k_global_desc),
decltype(a_e0_e1_k0_k1_e2_block_copy_desc), decltype(a_e_k_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>, // ABlockTransferDstAccessOrder Sequence<0, 1>,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
4, // ABlockTransferDstVectorDim 1,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_E2, ABlockTransferDstScalarPerVector_K,
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
false>(a_e0_e1_k0_k1_e2_grid_desc, true>(a_e_k_global_desc,
make_multi_index(0, 0, k_block_work_id, 0, 0), make_multi_index(0, k_block_data_on_global),
a_e0_e1_k0_k1_e2_block_copy_desc, a_e_k_desc,
make_multi_index(0, 0, 0, 0, 0)); make_multi_index(0, 0));
constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0); constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1, auto b_threadwise_transfer =
Number<E1PerBlock>{}, ThreadwiseTensorSliceTransfer_v2<FloatAB,
I1,
I1,
I1,
Number<HoPerThread>{},
I1,
I1,
Number<WoPerThread>{},
Number<E2>{}));
auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB, FloatAB,
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc), decltype(b_e_n_ho_wo_global_desc),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc), decltype(b_e_n_ho_wo_thread_desc),
Sequence<I1, E1PerBlock, I1, I1, I1, HoPerThread, I1, I1, WoPerThread, E2>, Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun, 1,
true>(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, true>(
make_multi_index(0, b_e_n_ho_wo_global_desc,
0, make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0,
0));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize()); p_shared_block, a_e_k_desc.GetElementSpaceSize());
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc, FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true> true>
c_thread_buf; c_thread_buf;
#if 0
// initialize output thread tensor // initialize output thread tensor
ThreadwiseTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k1_n_h2_w2_thread_gemm_desc), decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, NPerBlock, HoPerThread, WoPerThread>>{} Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k1_n_h2_w2_thread_gemm_desc, .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
#endif
constexpr auto b_thread_slice_copy_step = constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
make_multi_index(0, E1PerBlock, 0, 0, 0, 0, 0, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e0_e1_k_e2_global_step_hacks = AGlobalStepHacks{}; constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{}; constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
BGlobalMoveSliceWindowStepHacks{};
// double register buffer for b // double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB, FloatAB,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc.GetElementSpaceSize(), b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true> true>
b_thread_even_buf, b_thread_odd_buf; b_thread_even_buf, b_thread_odd_buf;
if constexpr(HasMainE0BlockLoop)
{
const auto E0 = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0);
index_t e0_block_data_begin = 0;
do
{
// LDS double buffer: preload data // LDS double buffer: preload data
{ {
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); b_e_n_ho_wo_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
} }
__syncthreads(); __syncthreads();
if constexpr(HasMainE1BlockLoop) if constexpr(HasMainKBlockLoop)
{ {
index_t e1_block_data_begin = 0; index_t e_block_data_begin = 0;
// LDS double buffer: main body // LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow // use Do-While loop instead of For loop to simplify control flow
do do
{ {
// even iteration // even iteration
b_threadwise_transfer.MoveSrcSliceWindow( b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_thread_slice_copy_step);
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run( b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
b_threadwise_transfer.MoveSrcSliceWindow( b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_thread_slice_copy_step);
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run( b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
e1_block_data_begin += 2 * E1PerBlock; e_block_data_begin += 2 * EPerBlock;
} while(e1_block_data_begin < E1 - 2 * E1PerBlock); } while(e_block_data_begin < E - 2 * EPerBlock);
} }
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{ {
b_threadwise_transfer.MoveSrcSliceWindow( b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_thread_slice_copy_step);
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
@@ -647,234 +344,113 @@ struct GridwiseGemmDlops_km_kn_mn_v2
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
} }
a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k0_k1_e2_grid_desc, // output: register to global memory
a_block_slice_copy_step,
AGlobalMoveSliceWindowStepHacks{});
blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
e0_block_data_begin += 1;
} while(e0_block_data_begin < E0);
}
else
{
// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(
a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf);
}
__syncthreads();
if constexpr(HasMainE1BlockLoop)
{
index_t e1_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
e1_block_data_begin += 2 * E1PerBlock;
} while(e1_block_data_begin < E1 - 2 * E1PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
{ {
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
b_thread_slice_copy_step, constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data const index_t k_thread_data_on_global =
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); k_block_data_on_global + k_thread_id * KPerThread;
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
// LDS double buffer: GEMM on last data ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); FloatC,
} decltype(c_k_n_ho_wo_thread_desc),
else // if has 1 iteration left decltype(c_k_n_ho_wo_global_desc),
{ Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
// LDS double buffer: GEMM on last data CThreadTransferSrcDstAccessOrder,
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(
c_k_n_ho_wo_global_desc,
make_multi_index(
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
.Run(c_k_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
c_k_n_ho_wo_global_desc,
c_global_buf,
c_k_n_ho_wo_global_tensor_step_hacks);
} }
} }
// activ // pass tensor descriptor by reference
{ template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) { __device__ void Run(const AGlobalDesc& a_e_k_global_desc,
if constexpr(activ_type == 1) const FloatAB* __restrict__ p_a_global,
{ const BGlobalDesc& b_e_n_ho_wo_global_desc,
c_thread_buf(i) = const FloatAB* __restrict__ p_b_global,
c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i]; const CGlobalDesc& c_k_n_ho_wo_global_desc,
} FloatC* __restrict__ p_c_global,
else if constexpr(activ_type == 2) integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
FloatAcc x = 1.0 + exp(-c_thread_buf[i]); constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
asm volatile("\n \ __shared__ FloatAB p_shared_block[shared_block_size];
v_rcp_f32 %0, %1 \n"
: "=v"(x)
: "0"(x));
c_thread_buf(i) = x; Run(a_e_k_global_desc,
} p_a_global,
}); b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
// bias // pass tensor descriptors by their pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
constexpr auto bias_k0_k1_thread_desc = const auto a_e_k_global_desc = *p_a_e_k_global_desc;
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<KPerThread>{})); const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc;
const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc;
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatC, Run(a_e_k_global_desc,
bias_k0_k1_thread_desc.GetElementSpaceSize(), p_a_global,
true> b_e_n_ho_wo_global_desc,
bias_thread_buf; p_b_global,
c_k_n_ho_wo_global_desc,
const index_t k_thread_data_on_global = k_thread_id * KPerThread; p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
auto bias_threadwise_transfer = integral_constant<bool, HasDoubleTailKBlockLoop>{});
ThreadwiseTensorSliceTransfer_v2<FloatC,
FloatC,
decltype(bias_k0_k1_grid_desc),
decltype(bias_k0_k1_thread_desc),
Sequence<I1, Number<KPerThread>{}>,
Sequence<0, 1>,
1,
CThreadTransferDstScalarPerVector,
false,
true>(
bias_k0_k1_grid_desc,
make_multi_index(k_block_work_id, k_thread_data_on_global));
constexpr auto bias_k0_k1_global_tensor_step_hacks = make_tuple(
make_tuple(Sequence<0>{}, Sequence<0>{}), make_tuple(Sequence<0>{}, Sequence<0>{}));
bias_threadwise_transfer.Run(bias_k0_k1_grid_desc,
bias_global_buf,
bias_k0_k1_thread_desc,
make_tuple(I0, I0),
bias_thread_buf,
bias_k0_k1_global_tensor_step_hacks);
static_for<0, KPerThread, 1>{}([&](auto ki) {
static_for<0, HoPerThread, 1>{}([&](auto hi) {
static_for<0, WoPerThread, 1>{}([&](auto wi) {
constexpr index_t c_offset = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(ki, 0, hi, wi));
c_thread_buf(Number<c_offset>{}) =
c_thread_buf[Number<c_offset>{}] + bias_thread_buf[ki];
});
});
});
} }
// output: register to global memory // pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const void* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
// hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
// tensor const auto b_e_n_ho_wo_global_desc =
constexpr auto c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = CGlobalStepHacks{}; *reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
const auto c_k_n_ho_wo_global_desc =
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc = *reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{}, Run(a_e_k_global_desc,
I1, p_a_global,
I1, b_e_n_ho_wo_global_desc,
I1, p_b_global,
Number<HoPerThread>{}, c_k_n_ho_wo_global_desc,
I1, p_c_global,
I1, integral_constant<bool, HasMainKBlockLoop>{},
Number<WoPerThread>{})); integral_constant<bool, HasDoubleTailKBlockLoop>{});
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc),
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc),
Sequence<I1, KPerThread, I1, I1, I1, HoPerThread, I1, I1, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
make_multi_index(k_block_work_id,
k_thread_data_on_global,
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0))
.Run(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_global_buf,
c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks);
}
} }
}; };