Commit 581d244c authored by Rosty Geyyer

Add gridwise gemm supporting batched input

parent a768dea5
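
Summary of the change, as read from the hunks below: the DL backward-weight convolution path gains a batch dimension (used for the KBatch split). The A/B grid descriptors get a leading B dimension (AGridDesc_K0_M0_M1_K1 becomes AGridDesc_B_K0_M0_M1_K1, and likewise for B), the gridwise GEMM switches from GridwiseGemmDl_km_kn_mn_v1r3 to the batched GridwiseGemmDl_bkm_bkn_mn_v1r3, the default block-to-C-tile map is replaced by a batch-aware block cluster adaptor, the kernel entry point applies per-batch pointer offsets, and the vector-length validity checks shift every tuple index up by one to account for the new leading dimension.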
@@ -56,20 +56,20 @@ using DeviceConvBwdWeightInstance =
 1, // KPerThread
 S<8, 2>, // M1N1ThreadClusterM1Xs
 S<8, 2>, // M1N1ThreadClusterN1Xs
-S<8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
-S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
-S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
-S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
-S<4, 1, 1, 2>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
-S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
-S<1, 1, 1, 2>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
-S<1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
-S<16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
-S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
-S<0, 3, 1, 2>, // BBlockTransferSrcAccessOrder
-S<1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
-S<0, 3, 1, 2>, // BBlockTransferSrcVectorTensorContiguousDimOrder
-S<1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
+S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
+S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
+S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder
+S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder
+S<1, 4, 1, 1, 2>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
+S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
+S<1, 1, 1, 1, 2>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
+S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
+S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
+S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
+S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder
+S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
+S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
+S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
 S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
 5, // CThreadTransferSrcDstVectorDim
 4>; // CThreadTransferDstScalarPerVector
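
Each transfer-tuning Sequence simply gains a leading batch entry of 1 (for example S<8, 1, 1, 2> becomes S<1, 8, 1, 1, 2>), and the access-order tuples are renumbered so dimension 0 is the batch and the old dimensions shift up by one. A minimal stand-alone sketch of that transformation, using a toy S in place of ck::Sequence (the `prepend` helper is illustrative, not a CK utility):

#include <type_traits>

// Toy stand-in for ck::Sequence; the real type lives in the CK headers.
template <int... Is>
struct S {};

// Prepend a compile-time value, turning a 4-d K0_M0_M1_K1 tuple into the
// 5-d B_K0_M0_M1_K1 tuple consumed by the batched gridwise GEMM.
template <int Lead, typename Seq>
struct prepend;

template <int Lead, int... Is>
struct prepend<Lead, S<Is...>>
{
    using type = S<Lead, Is...>;
};

// Matches the edit above: S<8, 1, 1, 2> -> S<1, 8, 1, 1, 2>.
static_assert(std::is_same_v<prepend<1, S<8, 1, 1, 2>>::type,
                             S<1, 8, 1, 1, 2>>,
              "leading batch dimension of 1 prepended");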
@@ -50,10 +50,10 @@ struct ComputePtrOffsetOfStridedBatch
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename AGridDesc_K0_M0_M1_K1,
-          typename BGridDesc_K0_N0_N1_K1,
+          typename AGridDesc_B_K0_M0_M1_K1,
+          typename BGridDesc_B_K0_N0_N1_K1,
           typename CGridDesc_M0_M10_M11_N0_N10_N11,
-          typename DefaultBlock2CTileMap,
+          typename Block2CTileMap,
           typename ComputePtrOffsetOfBatch,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
@@ -66,10 +66,10 @@ __global__ void
 const FloatAB* __restrict__ p_b_grid,
 FloatC* __restrict__ p_c_grid,
 const index_t batch_count,
-const AGridDesc_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
-const BGridDesc_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
+const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
+const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
 const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
-const DefaultBlock2CTileMap block_2_ctile_map,
+const Block2CTileMap block_2_ctile_map,
 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
 const index_t num_blocks_per_batch =
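
The kernel divides the launch grid into batch_count contiguous groups of workgroups and offsets the A/B/C base pointers per batch. A hedged sketch of the arithmetic involved — the struct layout and method names follow the usual CK strided-batch convention but are assumptions here, not copied from this file:

#include <cstdint>

// Sketch of a strided-batch offset helper (names assumed): batch g is reached
// by adding g * BatchStride to the base pointer of A, B, or C.
struct ComputePtrOffsetOfStridedBatchSketch
{
    int64_t BatchStrideA_;
    int64_t BatchStrideB_;
    int64_t BatchStrideC_;

    int64_t GetAPtrOffset(int g_idx) const { return g_idx * BatchStrideA_; }
    int64_t GetBPtrOffset(int g_idx) const { return g_idx * BatchStrideB_; }
    int64_t GetCPtrOffset(int g_idx) const { return g_idx * BatchStrideC_; }
};

// Inside the kernel: workgroups are grouped contiguously per batch, so the
// batch index is recovered from the 1-d block id.
inline int GetBatchIndex(int block_id, int num_blocks_per_batch)
{
    return block_id / num_blocks_per_batch;
}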
@@ -85,7 +85,7 @@ __global__ void
 __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
-GridwiseGemm::template Run<HasMainKBlockLoop>(
+GridwiseGemm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop>(
 p_a_grid + a_batch_offset,
 p_b_grid + b_batch_offset,
 p_c_grid + c_batch_offset,
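
Run now also receives HasDoubleTailKBlockLoop, which tells the double-buffered main loop whether two tail iterations remain to drain in the epilogue. A hedged sketch of how the two flags are typically derived in CK's DL GEMMs — the helper names and exact formulas are assumptions modeled on the non-batched variant, not copied from this diff:

// Assumed derivation of the two compile-time loop flags from the K0 extent.
bool CalculateHasMainKBlockLoop(int K0, int K0PerBlock)
{
    // More than one double-buffered (two-tile) pass over K0 remains.
    return (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
}

bool CalculateHasDoubleTailKBlockLoop(int K0, int K0PerBlock)
{
    // An even pass count leaves two tail iterations for the epilogue.
    return (K0 / K0PerBlock) % 2 == 0;
}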
@@ -729,55 +729,55 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
 using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
-using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
-using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
-using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
+using AGridDesc_B_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
+using BGridDesc_B_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
+using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
 using GridwiseGemm =
-GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
-ADataType,
-AccDataType,
-CDataType,
-InMemoryDataOperationEnum::Set,
-AGridDesc_K0_M_K1,
-BGridDesc_K0_N_K1,
-CGridDesc_M_N,
-MPerBlock,
-NPerBlock,
-K0PerBlock,
-K1,
-M1PerThread,
-N1PerThread,
-KPerThread,
-M1N1ThreadClusterM1Xs,
-M1N1ThreadClusterN1Xs,
-ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
-ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
-ABlockTransferThreadClusterArrangeOrder,
-ABlockTransferSrcAccessOrder,
-ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
-ABlockTransferSrcVectorTensorContiguousDimOrder,
-ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
-BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
-BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
-BBlockTransferThreadClusterArrangeOrder,
-BBlockTransferSrcAccessOrder,
-BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
-BBlockTransferSrcVectorTensorContiguousDimOrder,
-BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
-CThreadTransferSrcDstAccessOrder,
-CThreadTransferSrcDstVectorDim,
-CThreadTransferDstScalarPerVector>;
+GridwiseGemmDl_bkm_bkn_mn_v1r3<BlockSize,
+ADataType,
+AccDataType,
+CDataType,
+InMemoryDataOperationEnum::Set,
+AGridDesc_B_K0_M_K1,
+BGridDesc_B_K0_N_K1,
+CGridDesc_M_N,
+MPerBlock,
+NPerBlock,
+K0PerBlock,
+K1,
+M1PerThread,
+N1PerThread,
+KPerThread,
+M1N1ThreadClusterM1Xs,
+M1N1ThreadClusterN1Xs,
+ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
+ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
+ABlockTransferThreadClusterArrangeOrder,
+ABlockTransferSrcAccessOrder,
+ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
+ABlockTransferSrcVectorTensorContiguousDimOrder,
+ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
+BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
+BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
+BBlockTransferThreadClusterArrangeOrder,
+BBlockTransferSrcAccessOrder,
+BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
+BBlockTransferSrcVectorTensorContiguousDimOrder,
+BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
+CThreadTransferSrcDstAccessOrder,
+CThreadTransferSrcDstVectorDim,
+CThreadTransferDstScalarPerVector>;
 // Argument
-using AGridDesc_K0_M0_M1_K1 =
-decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
-using BGridDesc_K0_N0_N1_K1 =
-decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
+using AGridDesc_B_K0_M0_M1_K1 =
+decltype(GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{}));
+using BGridDesc_B_K0_N0_N1_K1 =
+decltype(GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{}));
 using CGridDesc_M0_M10_M11_N0_N10_N11 =
 decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
-using DefaultBlock2CTileMap =
-decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
+using Block2CTileMap =
+decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
 struct Argument : public BaseArgument
 {
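
The default M/N-only tile map is replaced by MakeCBlockClusterAdaptor, whose bottom index also carries the k-batch slice. A plain-C++ sketch of the index arithmetic such a map encodes when M01 = N01 = 1 — the ordering and the function name are assumptions; the real map is a CK tensor adaptor:

#include <array>

// Sketch: consecutive 1-d block ids first sweep the N tiles, then the M
// tiles, then the k-batch slices (layout assumed).
std::array<int, 3> CalculateBottomIndex(int block_1d_id, int M0, int N0)
{
    const int idx_kbatch = block_1d_id / (M0 * N0);
    const int rem        = block_1d_id % (M0 * N0);
    return {idx_kbatch, rem / N0, rem % N0}; // {kbatch, m0, n0}
}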
@@ -842,12 +842,15 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
 c_grid_desc_m_n_ = descs[I2];
 a_grid_desc_kbatch_k0_m0_m1_k1_ =
-GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
+GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
 b_grid_desc_kbatch_k0_n0_n1_k1_ =
-GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
+GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
 c_grid_desc_m0_m10_m11_n0_n10_n11_ =
 GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
-block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
+ck::index_t M01 = 1;
+ck::index_t N01 = 1;
+block_2_ctile_map_ =
+GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
 // A/B/C Batch Stride
 compute_ptr_offset_of_batch_.BatchStrideA_ =
@@ -874,15 +877,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
 const BDataType* p_b_grid_;
 CDataType* p_c_grid_;
-AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
-BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
+AGridDesc_B_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
+BGridDesc_B_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
 CGridDesc_M_N c_grid_desc_m_n_;
-AGridDesc_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_;
-BGridDesc_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_;
+AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_;
+BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_;
 CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
-DefaultBlock2CTileMap block_2_ctile_map_;
+// DefaultBlock2CTileMap block_2_ctile_map_;
+Block2CTileMap block_2_ctile_map_;
 // for computing batch offset
 ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
@@ -941,7 +945,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
 arg.c_grid_desc_m_n_))
 {
 throw std::runtime_error(
-"wrong! GridwiseGemm GridwiseGemmDl_km_kn_mn_v1r3 has invalid setting");
+"wrong! GridwiseGemm GridwiseGemmDl_bkm_bkn_mn_v1r3 has invalid setting");
 }
 const index_t grid_size =
@@ -950,16 +954,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
 auto launch_kernel = [&](auto has_main_k_block_loop,
                          auto has_double_tail_k_block_loop) {
 constexpr bool has_main_loop = has_main_k_block_loop.value;
-constexpr bool has_double_loop = has_double_tail_k_block_loop;
+constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
 const auto kernel = kernel_batched_gemm_dlops_bwd_weight<
 GridwiseGemm,
 ADataType, // TODO: distiguish A/B datatype
 CDataType,
-remove_reference_t<DeviceOp::AGridDesc_K0_M0_M1_K1>,
-remove_reference_t<DeviceOp::BGridDesc_K0_N0_N1_K1>,
+remove_reference_t<DeviceOp::AGridDesc_B_K0_M0_M1_K1>,
+remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
 remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
-remove_reference_t<DeviceOp::DefaultBlock2CTileMap>,
+remove_reference_t<DeviceOp::Block2CTileMap>,
 ComputePtrOffsetOfStridedBatch,
 has_main_loop,
 has_double_loop>;
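
The fix in this hunk reads has_double_tail_k_block_loop through .value, matching how has_main_k_block_loop is already read: both lambda parameters arrive as integral_constant-style tags, and reading the static .value member is the uniform, always-valid way to get a compile-time bool out of them. A self-contained sketch of the tag-dispatch pattern (kernel selection reduced to a print; names are illustrative):

#include <cstdio>
#include <type_traits>

// Stand-in for the kernel template instantiated with compile-time flags.
template <bool HasMainLoop, bool HasDoubleLoop>
void run_kernel()
{
    std::printf("main=%d double=%d\n", HasMainLoop, HasDoubleLoop);
}

// Each flag arrives as a distinct type, so .value is usable as a
// template argument -- the same shape as the launch_kernel lambda above.
template <typename MainLoopTag, typename DoubleTailTag>
void dispatch(MainLoopTag, DoubleTailTag)
{
    constexpr bool has_main_loop   = MainLoopTag::value;
    constexpr bool has_double_loop = DoubleTailTag::value; // the fixed line
    run_kernel<has_main_loop, has_double_loop>();
}

int main()
{
    dispatch(std::integral_constant<bool, true>{},
             std::integral_constant<bool, false>{});
}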
@@ -1045,18 +1049,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
 // matrix A
 {
 auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
-if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
+if(srcVectorLengths[I2] != 1 || srcVectorLengths[I3] != 1)
 {
 return false;
 }
-if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
+if(K1 % srcVectorLengths[I4] != 0 || K0PerBlock % srcVectorLengths[I1] != 0)
 {
 return false;
 }
 const index_t K = arg.Conv_K_;
-if(K % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
+if(K % (srcVectorLengths[I1] * srcVectorLengths[I4]) != 0)
 {
 return false;
 }
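
Every index in these validity checks shifts up by one (I0 to I1, I3 to I4, and so on) because the vector-length tuples now carry the batch in position 0; the matrix-B hunk below applies the same one-position shift on the N side. A plain-int sketch of the shifted matrix-A check — the function name and the int[5] stand-in are illustrative, not CK API:

// Shifted A-matrix check: with the leading batch dimension, K0 sits at
// index 1, M0/M1 at indices 2/3, and K1 at index 4 of the 5-d tuple.
bool IsValidAVectorLengths(const int len[5], int K0PerBlock, int K1, int K)
{
    if(len[2] != 1 || len[3] != 1) // M0 and M1 must not be vectorized
        return false;
    if(K1 % len[4] != 0 || K0PerBlock % len[1] != 0)
        return false;
    return K % (len[1] * len[4]) == 0; // combined K vector width divides K
}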
@@ -1066,19 +1070,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
 {
 auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
 auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
-if(srcVectorLengths[I0] != 1 || srcVectorLengths[I3] != 1)
+if(srcVectorLengths[I1] != 1 || srcVectorLengths[I4] != 1)
 {
 return false;
 }
-if(srcLoadLenghts[I1] % srcVectorLengths[I1] != 0 ||
-srcLoadLenghts[I2] % srcVectorLengths[I2] != 0)
+if(srcLoadLenghts[I2] % srcVectorLengths[I2] != 0 ||
+srcLoadLenghts[I3] % srcVectorLengths[I3] != 0)
 {
 return false;
 }
 const index_t C = arg.Conv_K_;
-if(C % (srcVectorLengths[I1] * srcVectorLengths[I2]) != 0)
+if(C % (srcVectorLengths[I2] * srcVectorLengths[I3]) != 0)
 {
 return false;
 }