"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "b2c43ffd4ce8db4cf8c6516c89775239c28a5464"
Unverified Commit b8b2d0a6 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

DL GEMM fp32/fp16/int8 (#41)

* add threadwise copy the copy a tensor in one copy, added kpack to DL GEMM

* add kpack into fwd v4r5 nchw fp32
parent 11ec07e9
#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP
#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r2.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGKGM0GM1GridDesc,
typename BGKGN0GN1GridDesc,
typename CGM0GM1GN0GN1GridDesc,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float
driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AGKGM0GM1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc,
const BGKGN0GN1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
// GEMM
using GridwiseContraction = GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGKGM0GM1GridDesc,
BGKGN0GN1GridDesc,
CGM0GM1GN0GN1GridDesc,
GM1PerBlockGM11,
GN1PerBlockGN11,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
if(!GridwiseContraction::CheckValidity(
a_gk0_gm0_gm1_gk1_grid_desc, b_gk0_gn0_gn1_gk1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting");
}
const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc =
GridwiseContraction::MakeAGK0GM0GM10GM11GK1GridDescriptor(a_gk0_gm0_gm1_gk1_grid_desc);
const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc =
GridwiseContraction::MakeBGK0GN0GN10GN11GK1GridDescriptor(b_gk0_gn0_gn1_gk1_grid_desc);
using AGK0GM0GM10GM11GK1GridDesc = decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc);
using BGK0GN0GN10GN11GK1GridDesc = decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc);
// c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);
using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);
// c_blockid_to_gm10_gn10_block_cluster_adaptor
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor);
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc);
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);
const bool has_double_tail_k_block_loop =
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);
{
std::cout << "a_gk0_gm0_gm10_gm11_gk1_grid_desc{"
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I1) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I2) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I3) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;
std::cout << "b_gk0_gn0_gn10_gn11_gk1_grid_desc{"
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I0) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I1) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I2) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I3) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;
std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl;
}
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
return ave_time;
}
} // namespace ck
#endif
This diff is collapsed.
...@@ -12,7 +12,7 @@ namespace ck { ...@@ -12,7 +12,7 @@ namespace ck {
// C: out // C: out
// GemmM = N * Ho * Wo // GemmM = N * Ho * Wo
// GemmN = K // GemmN = K
// GemmK = C * Y * X // GemmK = Y * X * C
template <typename... In, template <typename... In,
typename... Wei, typename... Wei,
typename... Out, typename... Out,
......
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t N0Value,
index_t C0Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<N0Value>,
Number<C0Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
constexpr auto N0 = Number<N0Value>{};
constexpr auto C0 = Number<C0Value>{};
const auto N1 = N / N0;
const auto C1 = C / C0;
// weight tensor
const auto wei_gk0_gm0_gm1_gk1_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));
// input tensor
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(C0, C1)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_dynamic_tensor_descriptor(
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C1, Y, X)),
make_pass_through_transform(N0),
make_merge_transform(make_tuple(N1, Ho, Wo)),
make_pass_through_transform(C0)),
make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor
const auto out_n_k_howo_grid_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo));
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
out_n_k_howo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)),
make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K),
make_pass_through_transform(Number<N0>{}),
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return make_tuple(
wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc);
}
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
InMemoryDataOperation DstInMemOp,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
typename SrcVectorTensorLengths,
typename DstVectorTensorLengths,
typename SrcVectorTensorContiguousDimOrder,
typename DstVectorTensorContiguousDimOrder,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin)
: threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, typename SrcIteratorHacks>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const SrcIteratorHacks& src_iterator_hacks)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
}
}
template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf);
}
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
// SrcMoveSliceWindowIteratorHack to control index calculation move slice window
template <typename SrcMoveSliceWindowIteratorHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_iterator_hack);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v3r1<ThreadSliceLengths,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorTensorLengths,
DstVectorTensorLengths,
SrcVectorTensorContiguousDimOrder,
DstVectorTensorContiguousDimOrder,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GEMM_V2R3_HPP
#define CK_BLOCKWISE_GEMM_V2R3_HPP
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_gemm_v2.hpp"
namespace ck {
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
// 1. A:
// 1. AK0MK1BlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BK0NK1BlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CM0M1N0N1ThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t M1PerThreadM11,
index_t N1PerThreadN11,
index_t KPerThread,
index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterM101,
index_t M1N1ThreadClusterN101,
index_t AThreadCopyScalarPerVector_M11,
index_t BThreadCopyScalarPerVector_N11,
typename std::enable_if<AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t M = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t N = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t M100 = M1N1ThreadClusterM100;
static constexpr index_t N100 = M1N1ThreadClusterN100;
static constexpr index_t M101 = M1N1ThreadClusterM101;
static constexpr index_t N101 = M1N1ThreadClusterN101;
static constexpr index_t M11 = M1PerThreadM11;
static constexpr index_t N11 = N1PerThreadN11;
static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
static constexpr index_t M0 = M / M1;
static constexpr index_t N0 = N / N1;
__host__ __device__ static constexpr auto
MakeAK0M0M1K1BlockDescriptor(const AK0MK1BlockDesc& a_k0_m_k1_block_desc)
{
const auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<K0>{}),
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_k0_m0_m1_k1_block_desc;
}
__host__ __device__ static constexpr auto
MakeBK0N0N1K1BlockDescriptor(const BK0NK1BlockDesc& b_k0_n_k1_block_desc)
{
const auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<K0>{}),
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_k0_n0_n1_k1_block_desc;
}
__host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M, N]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(
Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
make_unmerge_transform(make_tuple(
Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
}
__host__ __device__ static constexpr auto
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M0, M1, N0, N1]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(Number<M0>{}),
make_unmerge_transform(
make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
make_pass_through_transform(Number<N0>{}),
make_unmerge_transform(
make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
}
__host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
{
return Sequence<M0, M11, N0, N11>{};
}
static constexpr auto a_k0_m0_m1_k1_block_desc_ =
MakeAK0M0M1K1BlockDescriptor(AK0MK1BlockDesc{});
static constexpr auto b_k0_n0_n1_k1_block_desc_ =
MakeBK0N0N1K1BlockDescriptor(BK0NK1BlockDesc{});
public:
__device__ BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2()
: c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == M101 * M100 * N101 * N100,
"wrong! blocksize and cluster size not consistent");
static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2, "wrong");
}
__device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
{
// lower: [M0, M1, N0, N1]
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();
// lower: [M0, M100, M101, M11, N0, N100, N101, N11]
// upper: [Tid, M0, M11, N0, N11]
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
make_pass_through_transform(M0),
make_pass_through_transform(M11),
make_pass_through_transform(N0),
make_pass_through_transform(N11)),
make_tuple(
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0));
}
template <typename CM0M1N0N1ThreadDesc,
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
"wrong");
auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
a_k0_m0_m1_k1_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
b_k0_n0_n1_k1_thread_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm =
ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1<FloatA,
FloatB,
FloatC,
decltype(a_k0_m0_m1_k1_thread_desc_),
decltype(b_k0_n0_n1_k1_thread_desc_),
CM0M1N0N1ThreadDesc,
Sequence<KPerThread, K1>,
Sequence<1, M1PerThreadM11>,
Sequence<1, N1PerThreadN11>>{};
// read A_sub_0
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(I0, I0, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(I0, I0, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(I0, I1, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(I0, I1, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
// loop over rest of k
static_for<KPerThread, K0, KPerThread>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(k, I0, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// read B_sub_0
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(k, I0, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
// read B_sub_1
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(k, I1, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(k, I1, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
}
private:
// A[K0, M0, M1, K1]
static constexpr auto a_k0_m0_m1_k1_thread_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}, Number<K1>{}));
// B[K0, N0, N1, K1]
static constexpr auto b_k0_n0_n1_k1_thread_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}, Number<K1>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
FloatA,
FloatA,
decltype(a_k0_m0_m1_k1_block_desc_),
decltype(a_k0_m0_m1_k1_thread_desc_),
Sequence<KPerThread, 1, M1PerThreadM11, K1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, M1PerThreadM11, K1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
FloatB,
FloatB,
decltype(b_k0_n0_n1_k1_block_desc_),
decltype(b_k0_n0_n1_k1_thread_desc_),
Sequence<KPerThread, 1, N1PerThreadN11, K1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, N1PerThreadN11, K1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
#endif
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
namespace ck { namespace ck {
// C[M, N] += transpose(A[K, M]) * B[K, N] // C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1]
// Element of matrix can be vectorized data // Tensor element can be vectorized data
// Assume: // Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time // 1. ADesc, BDesc, CDesc are known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time // 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
...@@ -17,11 +17,27 @@ template <typename FloatA, ...@@ -17,11 +17,27 @@ template <typename FloatA,
typename ADesc, typename ADesc,
typename BDesc, typename BDesc,
typename CDesc, typename CDesc,
typename KLengths,
typename MLengths,
typename NLengths,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v1r1 struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
{ {
__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths
// TODO remove this restriction
static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2,
"wrong!");
}
template <typename ABuffer, template <typename ABuffer,
typename AOriginIdx, typename AOriginIdx,
typename BBuffer, typename BBuffer,
...@@ -35,10 +51,6 @@ struct ThreadwiseGemm_km_kn_mn_v1r1 ...@@ -35,10 +51,6 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
CBuffer& c_buf, CBuffer& c_buf,
COriginIdx) COriginIdx)
{ {
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert( static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value && is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value && is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
...@@ -58,89 +70,42 @@ struct ThreadwiseGemm_km_kn_mn_v1r1 ...@@ -58,89 +70,42 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto M = CDesc{}.GetLength(I0); constexpr auto K = KLengths{}[I0];
constexpr auto N = CDesc{}.GetLength(I1); constexpr auto M0 = MLengths{}[I0];
constexpr auto K = ADesc{}.GetLength(I0); constexpr auto M1 = MLengths{}[I1];
constexpr auto N0 = NLengths{}[I0];
constexpr auto N1 = NLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, K, 1>{}([&](auto k) { static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) { static_for<0, M0, 1>{}([&](auto m0) {
constexpr index_t a_offset = static_for<0, M1, 1>{}([&](auto m1) {
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m)); static_for<0, N0, 1>{}([&](auto n0) {
static_for<0, N1, 1>{}([&](auto n1) {
#if 0
if constexpr(N == 2)
{
constexpr index_t b_offset_0 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
constexpr index_t b_offset_1 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
constexpr index_t c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
constexpr index_t c_offset_1 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
amd_assembly_outer_product_1x2(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset_0>{}],
b_buf[Number<b_offset_1>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}));
}
else if constexpr(N == 4)
{
constexpr index_t b_offset_0 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
constexpr index_t b_offset_1 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
constexpr index_t b_offset_2 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
constexpr index_t b_offset_3 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
constexpr index_t c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
constexpr index_t c_offset_1 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
constexpr index_t c_offset_2 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
constexpr index_t c_offset_3 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset_0>{}],
b_buf[Number<b_offset_1>{}],
b_buf[Number<b_offset_2>{}],
b_buf[Number<b_offset_3>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}),
c_buf(Number<c_offset_2>{}),
c_buf(Number<c_offset_3>{}));
}
else
#endif
{
static_for<0, N, 1>{}([&](auto n) {
constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1));
constexpr index_t b_offset = constexpr index_t b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n)); BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1));
constexpr index_t c_offset = constexpr index_t c_offset = CDesc{}.CalculateOffset(
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n)); c_origin_idx + make_multi_index(m0, m1, n0, n1));
amd_assembly_inner_product(a_buf[Number<a_offset>{}], amd_inner_product_dlop<FloatA, FloatB, FloatC>(
a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}], b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{})); c_buf(Number<c_offset>{}));
}); });
} });
});
}); });
}); });
} }
}; };
// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1] // C[M0, M1, N0, N1] += A[K0, M0, M1, K1] * B[K0, N0, N1, K1]
// Tensor element can be vectorized data // Tensor element can be vectorized data
// Assume: // Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time // 1. ADesc, BDesc, CDesc are known at compile-time
...@@ -157,9 +122,9 @@ template <typename FloatA, ...@@ -157,9 +122,9 @@ template <typename FloatA,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1
{ {
__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1() __device__ constexpr ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1()
{ {
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
...@@ -168,7 +133,7 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 ...@@ -168,7 +133,7 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths // TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths
// TODO remove this restriction // TODO remove this restriction
static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2, static_assert(KLengths::Size() == 2 && MLengths::Size() == 2 && NLengths::Size() == 2,
"wrong!"); "wrong!");
} }
...@@ -204,31 +169,47 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 ...@@ -204,31 +169,47 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto K = KLengths{}[I0]; constexpr index_t K0 = KLengths{}[I0];
constexpr auto M0 = MLengths{}[I0]; constexpr index_t K1 = KLengths{}[I1];
constexpr auto M1 = MLengths{}[I1]; constexpr index_t M0 = MLengths{}[I0];
constexpr auto N0 = NLengths{}[I0]; constexpr index_t M1 = MLengths{}[I1];
constexpr auto N1 = NLengths{}[I1]; constexpr index_t N0 = NLengths{}[I0];
constexpr index_t N1 = NLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, K, 1>{}([&](auto k) { static_for<0, K0, 1>{}([&](auto k0) {
static_for<0, M0, 1>{}([&](auto m0) { static_for<0, M0, 1>{}([&](auto m0) {
static_for<0, M1, 1>{}([&](auto m1) { static_for<0, M1, 1>{}([&](auto m1) {
static_for<0, N0, 1>{}([&](auto n0) { static_for<0, N0, 1>{}([&](auto n0) {
static_for<0, N1, 1>{}([&](auto n1) { static_for<0, N1, 1>{}([&](auto n1) {
constexpr index_t a_offset = vector_type<FloatA, K1> a_vec;
ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1)); vector_type<FloatB, K1> b_vec;
constexpr index_t b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1)); static_for<0, K1, 1>{}([&](auto k1) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(
a_origin_idx + make_multi_index(k0, m0, m1, k1));
constexpr index_t b_offset = BDesc{}.CalculateOffset(
b_origin_idx + make_multi_index(k0, n0, n1, k1));
a_vec.template AsType<FloatA>()(k1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(k1) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, K1>::type;
using b_vector_t = typename vector_type<FloatB, K1>::type;
constexpr index_t c_offset = CDesc{}.CalculateOffset( constexpr index_t c_offset = CDesc{}.CalculateOffset(
c_origin_idx + make_multi_index(m0, m1, n0, n1)); c_origin_idx + make_multi_index(m0, m1, n0, n1));
amd_assembly_inner_product(a_buf[Number<a_offset>{}], amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>(
b_buf[Number<b_offset>{}], a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{})); c_buf(Number<c_offset>{}));
}); });
}); });
......
#ifndef CK_AMD_DLOP_HPP
#define CK_AMD_DLOP_HPP
#include "float_type.hpp"
namespace ck {
template <typename TA, typename TB, typename TC>
__device__ void amd_inner_product_dlop(const TA& a, const TB& b, TC& c);
template <>
__device__ void
amd_inner_product_dlop<float, float, float>(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
asm volatile("\n \
v_fmac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
c += a * b;
#endif
}
#if CK_USE_AMD_DLOP
template <>
__device__ void
amd_inner_product_dlop<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
asm volatile("\n \
v_dot2_f32_f16 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
c = __builtin_amdgcn_sdot2(a, b, c, false);
#endif
}
template <>
__device__ void
amd_inner_product_dlop<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
c);
amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
c);
}
template <>
__device__ void
amd_inner_product_dlop<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
c);
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
c);
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
c);
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
c);
}
template <>
__device__ void amd_inner_product_dlop<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a,
const int8x4_t& b,
int32_t& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
}
template <>
__device__ void amd_inner_product_dlop<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a,
const int8x8_t& b,
int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
c);
amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
c);
}
template <>
__device__ void amd_inner_product_dlop<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a,
const int8x16_t& b,
int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
c);
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
c);
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
c);
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
c);
}
#endif // CK_USE_AMD_DLOP
} // namespace ck
#endif
...@@ -5,94 +5,16 @@ ...@@ -5,94 +5,16 @@
namespace ck { namespace ck {
// c += inner_product(a, b)
__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_V_FMAC_F32
asm volatile("\n \
v_fmac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
asm volatile("\n \
v_mac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#endif
}
__device__ void amd_assembly_inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
}
__device__ void amd_assembly_inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
c);
amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
c);
}
__device__ void amd_assembly_inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
c);
amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
c);
amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
c);
amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
c);
}
// c0 += inner_product(a, b0) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1) // c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{ {
#if CK_USE_AMD_V_FMAC_F32
asm volatile("\n \ asm volatile("\n \
v_fmac_f32 %0, %2, %3 \n \ v_fmac_f32 %0, %2, %3 \n \
v_fmac_f32 %1, %2, %4 \n \ v_fmac_f32 %1, %2, %4 \n \
" "
: "=v"(c0), "=v"(c1) : "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#else
asm volatile("\n \
v_mac_f32 %0, %2, %3 \n \
v_mac_f32 %1, %2, %4 \n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#endif
} }
// c0 += inner_product(a, b0) // c0 += inner_product(a, b0)
...@@ -102,7 +24,6 @@ __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, floa ...@@ -102,7 +24,6 @@ __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, floa
__device__ void amd_assembly_outer_product_1x4( __device__ void amd_assembly_outer_product_1x4(
float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3) float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{ {
#if CK_USE_AMD_V_FMAC_F32
asm volatile("\n \ asm volatile("\n \
v_fmac_f32 %0, %4, %5 \n \ v_fmac_f32 %0, %4, %5 \n \
v_fmac_f32 %1, %4, %6 \n \ v_fmac_f32 %1, %4, %6 \n \
...@@ -111,16 +32,6 @@ __device__ void amd_assembly_outer_product_1x4( ...@@ -111,16 +32,6 @@ __device__ void amd_assembly_outer_product_1x4(
" "
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#else
asm volatile("\n \
v_mac_f32 %0, %4, %5 \n \
v_mac_f32 %1, %4, %6 \n \
v_mac_f32 %2, %4, %7 \n \
v_mac_f32 %3, %4, %8 \n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#endif
} }
// c0 += inner_product(a, b0) // c0 += inner_product(a, b0)
......
...@@ -28,10 +28,15 @@ ...@@ -28,10 +28,15 @@
#include "static_buffer.hpp" #include "static_buffer.hpp"
#include "dynamic_buffer.hpp" #include "dynamic_buffer.hpp"
// TODO: remove this
#if CK_USE_AMD_INLINE_ASM #if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp" #include "amd_inline_asm.hpp"
#endif #endif
#if CK_USE_AMD_DLOP
#include "amd_dlop.hpp"
#endif
#if CK_USE_AMD_XDLOPS #if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp" #include "amd_xdlops.hpp"
#include "amd_xdlops_inline_asm.hpp" #include "amd_xdlops_inline_asm.hpp"
......
...@@ -54,8 +54,13 @@ ...@@ -54,8 +54,13 @@
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1 #define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
#endif #endif
#ifndef CK_USE_AMD_V_FMAC_F32 // AMD DLOPS
#define CK_USE_AMD_V_FMAC_F32 1 #ifndef CK_USE_AMD_DLOP
#define CK_USE_AMD_DLOP 1
#endif
#ifndef CK_USE_AMD_DLOP_INLINE_ASM
#define CK_USE_AMD_DLOP_INLINE_ASM 1
#endif #endif
// AMD buffer addressing // AMD buffer addressing
...@@ -116,7 +121,7 @@ ...@@ -116,7 +121,7 @@
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
// merge transformation use magic number division // merge transformation use magic number division
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1 #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
// hack: have underlying assumption that need to be satsified, otherwise it's a bug // hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
......
...@@ -94,7 +94,7 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(Sequence<Is.. ...@@ -94,7 +94,7 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(Sequence<Is..
constexpr auto new2old = typename sequence_map_inverse<Sequence<IRs...>>::type{}; constexpr auto new2old = typename sequence_map_inverse<Sequence<IRs...>>::type{};
return container_reorder_give_new2old(old_seq, new2old); return container_reorder_given_new2old(old_seq, new2old);
} }
#if !CK_WORKAROUND_SWDEV_275126 #if !CK_WORKAROUND_SWDEV_275126
...@@ -223,6 +223,13 @@ container_reverse_exclusive_scan(const Array<TData, NSize>& x, Reduce f, TData i ...@@ -223,6 +223,13 @@ container_reverse_exclusive_scan(const Array<TData, NSize>& x, Reduce f, TData i
return y; return y;
} }
template <index_t... Is, typename Reduce, index_t Init>
__host__ __device__ constexpr auto
container_reverse_exclusive_scan(const Sequence<Is...>& seq, Reduce f, Number<Init>)
{
return reverse_exclusive_scan_sequence(seq, f, Number<Init>{});
}
#if !CK_WORKAROUND_SWDEV_275126 #if !CK_WORKAROUND_SWDEV_275126
// rocm4.1 compiler would crash with recursive lambda // rocm4.1 compiler would crash with recursive lambda
template <typename... Xs, typename Reduce, typename Init> template <typename... Xs, typename Reduce, typename Init>
...@@ -366,6 +373,19 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>& ...@@ -366,6 +373,19 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
} }
template <typename Container>
__host__ __device__ constexpr auto to_tuple_of_number(const Container&)
{
static_assert(is_known_at_compile_time<Container>::value, "wrong!");
return generate_tuple(
[&](auto i) {
constexpr index_t tmp = Container::At(i);
return Number<tmp>{};
},
Container::Size());
}
template <index_t... Is> template <index_t... Is>
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>) __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
{ {
......
...@@ -100,13 +100,19 @@ struct DynamicBuffer ...@@ -100,13 +100,19 @@ struct DynamicBuffer
*reinterpret_cast<X*>(&p_data_[i]) = x; *reinterpret_cast<X*>(&p_data_[i]) = x;
#else #else
// HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
// ISA, so I try to let compiler emit use IR "store<i32, 4>" which would be lower to // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128 // ds_write_b128
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type, if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
int8_t>::value) int8_t>::value)
{ {
static_assert( static_assert(
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value && (is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) || is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value && (is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
...@@ -115,7 +121,32 @@ struct DynamicBuffer ...@@ -115,7 +121,32 @@ struct DynamicBuffer
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value), is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
"wrong! not implemented for this combination, please add implementation"); "wrong! not implemented for this combination, please add implementation");
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value && if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*reinterpret_cast<int8_t*>(&p_data_[i]) =
*reinterpret_cast<const int8_t*>(&x);
}
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*reinterpret_cast<int16_t*>(&p_data_[i]) =
*reinterpret_cast<const int16_t*>(&x);
}
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*reinterpret_cast<int32_t*>(&p_data_[i]) =
*reinterpret_cast<const int32_t*>(&x);
}
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
int8x4_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
...@@ -123,7 +154,8 @@ struct DynamicBuffer ...@@ -123,7 +154,8 @@ struct DynamicBuffer
*reinterpret_cast<int32_t*>(&p_data_[i]) = *reinterpret_cast<int32_t*>(&p_data_[i]) =
*reinterpret_cast<const int32_t*>(&x); *reinterpret_cast<const int32_t*>(&x);
} }
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value && else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
int8x8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
...@@ -131,7 +163,8 @@ struct DynamicBuffer ...@@ -131,7 +163,8 @@ struct DynamicBuffer
*reinterpret_cast<int32x2_t*>(&p_data_[i]) = *reinterpret_cast<int32x2_t*>(&p_data_[i]) =
*reinterpret_cast<const int32x2_t*>(&x); *reinterpret_cast<const int32x2_t*>(&x);
} }
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value && else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
int8x16_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value) is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
......
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp"
...@@ -24,23 +26,27 @@ ...@@ -24,23 +26,27 @@
#define USE_DYNAMIC_MODE 1 #define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NHWC 0 #define USE_CONV_FWD_V4R4_NHWC 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R5_NCHW 0 #define USE_CONV_FWD_V4R5_NCHW 0
#define USE_CONV_FWD_V4R5R2_NCHW 1
#define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4_XDL_NCHW 0 #define USE_CONV_FWD_V4R4_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0 #define USE_CONV_FWD_V4R4R2_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R3_XDL_NHWC 1 #define USE_CONV_FWD_V4R4R3_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
enum ConvForwardAlgo enum ConvForwardAlgo
{ {
V4R4NCHW, // 0 V4R4NCHW, // 0
V4R4NHWC, // 1 V4R4NHWC, // 1
V4R5NCHW, // 2 V4R4R2NHWC, // 2
V5R1NCHW, // 3 V4R5NCHW, // 3
V4R4XDLNCHW, // 4 V4R5R2NCHW, // 4
V4R4R2XDLNHWC, // 5 V5R1NCHW, // 5
V4R4R3XDLNHWC, // 6 V4R4XDLNCHW, // 6
V4R4R4XDLNHWC // 7 V4R4R2XDLNHWC, // 7
V4R4R3XDLNHWC, // 8
V4R4R4XDLNHWC // 9
}; };
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -132,18 +138,15 @@ int main(int argc, char* argv[]) ...@@ -132,18 +138,15 @@ int main(int argc, char* argv[])
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#endif #endif
#if 0 #if 1
constexpr index_t in_vector_size = 1;
using in_data_t = float; using in_data_t = float;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
#elif 1 #elif 1
constexpr index_t in_vector_size = 1;
using in_data_t = half_t; using in_data_t = half_t;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = half_t; using out_data_t = half_t;
#elif 1 #elif 1
constexpr index_t in_vector_size = 16;
using in_data_t = int8_t; using in_data_t = int8_t;
using acc_data_t = int32_t; using acc_data_t = int32_t;
using out_data_t = int8_t; using out_data_t = int8_t;
...@@ -348,6 +351,33 @@ int main(int argc, char* argv[]) ...@@ -348,6 +351,33 @@ int main(int argc, char* argv[])
} }
#endif #endif
#if USE_CONV_FWD_V4R4R2_NHWC
if(algo == ConvForwardAlgo::V4R4R2NHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk<in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V4R5_NCHW #if USE_CONV_FWD_V4R5_NCHW
if(algo == ConvForwardAlgo::V4R5NCHW) if(algo == ConvForwardAlgo::V4R5NCHW)
{ {
...@@ -374,6 +404,33 @@ int main(int argc, char* argv[]) ...@@ -374,6 +404,33 @@ int main(int argc, char* argv[])
} }
#endif #endif
#if USE_CONV_FWD_V4R5R2_NCHW
if(algo == ConvForwardAlgo::V4R5R2NCHW)
{
if(layout != ConvTensorLayout::NCHW)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw<in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V5R1_NCHW #if USE_CONV_FWD_V5R1_NCHW
if(algo == ConvForwardAlgo::V5R1NCHW) if(algo == ConvForwardAlgo::V5R1NCHW)
{ {
...@@ -385,7 +442,7 @@ int main(int argc, char* argv[]) ...@@ -385,7 +442,7 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
in_vector_size, 16,
acc_data_t, acc_data_t,
out_data_t>(tmp[I0], out_data_t>(tmp[I0],
tmp[I1], tmp[I1],
...@@ -525,10 +582,10 @@ int main(int argc, char* argv[]) ...@@ -525,10 +582,10 @@ int main(int argc, char* argv[])
#if 0 #if 0
if(do_log) if(do_log)
{ {
LogRange(std::cout << "in : ", in.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
} }
#endif #endif
} }
......
...@@ -55,7 +55,7 @@ float launch_and_time_kernel(F kernel, ...@@ -55,7 +55,7 @@ float launch_and_time_kernel(F kernel,
{ {
KernelTimer timer; KernelTimer timer;
printf("%s: block_dim {%d, %d, %d}, grid_dim {%d, %d, %d} \n", printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
__func__, __func__,
grid_dim.x, grid_dim.x,
grid_dim.y, grid_dim.y,
......
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_v1r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 1;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 8, 2] for fp16
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 2;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 8, 4] for i8
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 4;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0
Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10
Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10
Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11
Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_v1r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlockM1,
GemmNPerBlockN1,
GemmKPerBlock,
GemmM1PerThreadM111,
GemmN1PerThreadN111,
GemmKPerThread,
GemmM11N11ThreadClusterM1100,
GemmM11N11ThreadClusterN1100,
GemmM11N11ThreadClusterM1101,
GemmM11N11ThreadClusterN1101,
GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1,
GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1,
Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1,
GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1,
Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11,
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks)>(
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}
...@@ -275,12 +275,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -275,12 +275,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 Sequence<0, 0, 0, 0, 0>{})); // 7-: N1
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_dynamic_gemm_xdlops_v2r3< float ave_time = driver_dynamic_gemm_xdlops_v2r3<
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment