Commit 581d244c authored by Rosty Geyyer

Add gridwise gemm supporting batched input

parent a768dea5
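In short, this change threads a leading KBatch (B) dimension through the DL backward-weight GEMM path: the A/B grid descriptors and block-transfer tuning sequences grow from 4-d (K0, M0/N0, M1/N1, K1) to 5-d (B, K0, M0/N0, M1/N1, K1), the kernel now takes a batch_count and per-batch pointer offsets, and the default block-to-C-tile map is replaced by a block cluster adaptor that is aware of the K-batch split.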
@@ -56,20 +56,20 @@ using DeviceConvBwdWeightInstance =
     1,                   // KPerThread
     S<8, 2>,             // M1N1ThreadClusterM1Xs
     S<8, 2>,             // M1N1ThreadClusterN1Xs
-    S<8, 1, 1, 2>,       // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
+    S<1, 8, 1, 1, 2>,    // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
-    S<2, 1, 128, 1>,     // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
+    S<1, 2, 1, 128, 1>,  // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
-    S<1, 2, 0, 3>,       // ABlockTransferThreadClusterArrangeOrder
+    S<0, 2, 3, 1, 4>,    // ABlockTransferThreadClusterArrangeOrder
-    S<1, 2, 0, 3>,       // ABlockTransferSrcAccessOrder
+    S<0, 2, 3, 1, 4>,    // ABlockTransferSrcAccessOrder
-    S<4, 1, 1, 2>,       // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
+    S<1, 4, 1, 1, 2>,    // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
-    S<1, 2, 0, 3>,       // ABlockTransferSrcVectorTensorContiguousDimOrder
+    S<0, 2, 3, 1, 4>,    // ABlockTransferSrcVectorTensorContiguousDimOrder
-    S<1, 1, 1, 2>,       // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
+    S<1, 1, 1, 1, 2>,    // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
-    S<1, 1, 8, 2>,       // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
+    S<1, 1, 1, 8, 2>,    // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
-    S<16, 1, 16, 1>,     // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
+    S<1, 16, 1, 16, 1>,  // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
-    S<0, 3, 1, 2>,       // BBlockTransferThreadClusterArrangeOrder
+    S<0, 1, 4, 2, 3>,    // BBlockTransferThreadClusterArrangeOrder
-    S<0, 3, 1, 2>,       // BBlockTransferSrcAccessOrder
+    S<0, 1, 4, 2, 3>,    // BBlockTransferSrcAccessOrder
-    S<1, 1, 8, 1>,       // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
+    S<1, 1, 1, 8, 1>,    // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
-    S<0, 3, 1, 2>,       // BBlockTransferSrcVectorTensorContiguousDimOrder
+    S<0, 1, 4, 2, 3>,    // BBlockTransferSrcVectorTensorContiguousDimOrder
-    S<1, 1, 1, 2>,       // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
+    S<1, 1, 1, 1, 2>,    // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
     S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
     5,                   // CThreadTransferSrcDstVectorDim
     4>;                  // CThreadTransferDstScalarPerVector
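The instance update above is mechanical: every 4-d tuning sequence gains a leading batch extent of 1, and every order sequence becomes a 5-element permutation in which the new dim 0 stays first while the old indices shift up by one. A minimal reading aid (the aliases are hypothetical; only the S<> values come from the diff):

// Old 4-d dims: (K0, M0, M1, K1)    -> indices 0..3
// New 5-d dims: (B, K0, M0, M1, K1) -> indices 0..4
using OldSliceLengths = S<8, 1, 1, 2>;    // one A slice per (K0, M0, M1, K1)
using NewSliceLengths = S<1, 8, 1, 1, 2>; // B extent 1: a thread touches one batch slice
// Orders are permutations of dim indices: old (1, 2, 0, 3) becomes
// 0 followed by each old index plus one, i.e. (0, 2, 3, 1, 4), so B stays outermost.
using OldArrangeOrder = S<1, 2, 0, 3>;
using NewArrangeOrder = S<0, 2, 3, 1, 4>;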
...
@@ -50,10 +50,10 @@ struct ComputePtrOffsetOfStridedBatch
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename AGridDesc_K0_M0_M1_K1,
+          typename AGridDesc_B_K0_M0_M1_K1,
-          typename BGridDesc_K0_N0_N1_K1,
+          typename BGridDesc_B_K0_N0_N1_K1,
           typename CGridDesc_M0_M10_M11_N0_N10_N11,
-          typename DefaultBlock2CTileMap,
+          typename Block2CTileMap,
           typename ComputePtrOffsetOfBatch,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
@@ -66,10 +66,10 @@ __global__ void
         const FloatAB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
         const index_t batch_count,
-        const AGridDesc_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
+        const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
-        const BGridDesc_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
+        const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
         const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
-        const DefaultBlock2CTileMap block_2_ctile_map,
+        const Block2CTileMap block_2_ctile_map,
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
     const index_t num_blocks_per_batch =
@@ -85,7 +85,7 @@ __global__ void
     __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
-    GridwiseGemm::template Run<HasMainKBlockLoop>(
+    GridwiseGemm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop>(
         p_a_grid + a_batch_offset,
         p_b_grid + b_batch_offset,
         p_c_grid + c_batch_offset,
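For context, the elided kernel body above splits the 1-d launch grid into equal per-batch slices and offsets the raw A/B/C pointers before invoking the gridwise GEMM. A hedged sketch of that logic (helper names assumed from the ComputePtrOffsetOfStridedBatch struct visible in this file, not verbatim from this commit):

// Which batch does this workgroup serve?
const index_t num_blocks_per_batch =
    __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx =
    __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

// Shift each base pointer by this batch's element stride.
const long_index_t a_batch_offset =
    static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
    static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t c_batch_offset =
    static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));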
@@ -729,55 +729,55 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());

-    using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
+    using AGridDesc_B_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
-    using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
+    using BGridDesc_B_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
     using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

     using GridwiseGemm =
-        GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
+        GridwiseGemmDl_bkm_bkn_mn_v1r3<BlockSize,
                                      ADataType,
                                      AccDataType,
                                      CDataType,
                                      InMemoryDataOperationEnum::Set,
-                                     AGridDesc_K0_M_K1,
+                                     AGridDesc_B_K0_M_K1,
-                                     BGridDesc_K0_N_K1,
+                                     BGridDesc_B_K0_N_K1,
                                      CGridDesc_M_N,
                                      MPerBlock,
                                      NPerBlock,
                                      K0PerBlock,
                                      K1,
                                      M1PerThread,
                                      N1PerThread,
                                      KPerThread,
                                      M1N1ThreadClusterM1Xs,
                                      M1N1ThreadClusterN1Xs,
                                      ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                      ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                      ABlockTransferThreadClusterArrangeOrder,
                                      ABlockTransferSrcAccessOrder,
                                      ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                                      ABlockTransferSrcVectorTensorContiguousDimOrder,
                                      ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                                      BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                      BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                      BBlockTransferThreadClusterArrangeOrder,
                                      BBlockTransferSrcAccessOrder,
                                      BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                                      BBlockTransferSrcVectorTensorContiguousDimOrder,
                                      BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                                      CThreadTransferSrcDstAccessOrder,
                                      CThreadTransferSrcDstVectorDim,
                                      CThreadTransferDstScalarPerVector>;

     // Argument
-    using AGridDesc_K0_M0_M1_K1 =
-        decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
+    using AGridDesc_B_K0_M0_M1_K1 =
+        decltype(GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{}));
-    using BGridDesc_K0_N0_N1_K1 =
-        decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
+    using BGridDesc_B_K0_N0_N1_K1 =
+        decltype(GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{}));
     using CGridDesc_M0_M10_M11_N0_N10_N11 =
         decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
-    using DefaultBlock2CTileMap =
-        decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
+    using Block2CTileMap =
+        decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
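The typedef swap above is the interface change that carries the batching: MakeDefaultBlock2CTileMap partitioned C into (m0, n0) tiles only, while MakeCBlockClusterAdaptor, whose signature is assumed from the decltype above, also takes M01/N01 tile-cluster factors and a KBatch count, so a single 1-d block id decodes to a (kbatch, m0, n0) coordinate. A hedged usage sketch:

// M01 = N01 = 1 disables tile clustering; the trailing argument is the K split.
const auto block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor(
    c_grid_desc_m_n, /*M01*/ 1, /*N01*/ 1, /*KBatch*/ k_batch);
// The launcher is then expected to scale the grid to roughly
// (M / MPerBlock) * (N / NPerBlock) * k_batch blocks per GEMM, times batch_count.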
     struct Argument : public BaseArgument
     {
@@ -842,12 +842,15 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
             c_grid_desc_m_n_ = descs[I2];

             a_grid_desc_kbatch_k0_m0_m1_k1_ =
-                GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
+                GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
             b_grid_desc_kbatch_k0_n0_n1_k1_ =
-                GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
+                GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
             c_grid_desc_m0_m10_m11_n0_n10_n11_ =
                 GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);

-            block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
+            ck::index_t M01 = 1;
+            ck::index_t N01 = 1;
+            block_2_ctile_map_ =
+                GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);

             // A/B/C Batch Stride
             compute_ptr_offset_of_batch_.BatchStrideA_ =
@@ -874,15 +877,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;

-        AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
+        AGridDesc_B_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
-        BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
+        BGridDesc_B_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;

-        AGridDesc_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_;
+        AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_;
-        BGridDesc_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_;
+        BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_;
         CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;

-        DefaultBlock2CTileMap block_2_ctile_map_;
+        // DefaultBlock2CTileMap block_2_ctile_map_;
+        Block2CTileMap block_2_ctile_map_;

         // for computing batch offset
         ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
@@ -941,7 +945,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                    arg.c_grid_desc_m_n_))
             {
                 throw std::runtime_error(
-                    "wrong! GridwiseGemm GridwiseGemmDl_km_kn_mn_v1r3 has invalid setting");
+                    "wrong! GridwiseGemm GridwiseGemmDl_bkm_bkn_mn_v1r3 has invalid setting");
             }

             const index_t grid_size =
@@ -950,16 +954,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
             auto launch_kernel = [&](auto has_main_k_block_loop,
                                      auto has_double_tail_k_block_loop) {
                 constexpr bool has_main_loop   = has_main_k_block_loop.value;
-                constexpr bool has_double_loop = has_double_tail_k_block_loop;
+                constexpr bool has_double_loop = has_double_tail_k_block_loop.value;

                 const auto kernel = kernel_batched_gemm_dlops_bwd_weight<
                     GridwiseGemm,
                     ADataType, // TODO: distinguish A/B datatype
                     CDataType,
-                    remove_reference_t<DeviceOp::AGridDesc_K0_M0_M1_K1>,
+                    remove_reference_t<DeviceOp::AGridDesc_B_K0_M0_M1_K1>,
-                    remove_reference_t<DeviceOp::BGridDesc_K0_N0_N1_K1>,
+                    remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
                     remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
-                    remove_reference_t<DeviceOp::DefaultBlock2CTileMap>,
+                    remove_reference_t<DeviceOp::Block2CTileMap>,
                     ComputePtrOffsetOfStridedBatch,
                     has_main_loop,
                     has_double_loop>;
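The `.value` fix above is a consistency cleanup: both lambda parameters arrive as integral_constant<bool, ...> tags, and the double-tail flag previously leaned on the tag's implicit bool conversion while the main-loop flag read `.value`. The dispatch over the two flags is elided here; a plausible sketch of its shape (the Calculate* helper names are assumptions, not shown in this diff):

const bool has_main        = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);

float ave_time = 0;
if(has_main && has_double_tail)
    ave_time = launch_kernel(integral_constant<bool, true>{},
                             integral_constant<bool, true>{});
else if(has_main)
    ave_time = launch_kernel(integral_constant<bool, true>{},
                             integral_constant<bool, false>{});
else if(has_double_tail)
    ave_time = launch_kernel(integral_constant<bool, false>{},
                             integral_constant<bool, true>{});
else
    ave_time = launch_kernel(integral_constant<bool, false>{},
                             integral_constant<bool, false>{});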
@@ -1045,18 +1049,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         // matrix A
         {
             auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
-            if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
+            if(srcVectorLengths[I2] != 1 || srcVectorLengths[I3] != 1)
             {
                 return false;
             }
-            if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
+            if(K1 % srcVectorLengths[I4] != 0 || K0PerBlock % srcVectorLengths[I1] != 0)
             {
                 return false;
             }

             const index_t K = arg.Conv_K_;
-            if(K % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
+            if(K % (srcVectorLengths[I1] * srcVectorLengths[I4]) != 0)
             {
                 return false;
             }
@@ -1066,19 +1070,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         {
             auto srcLoadLenghts   = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
             auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
-            if(srcVectorLengths[I0] != 1 || srcVectorLengths[I3] != 1)
+            if(srcVectorLengths[I1] != 1 || srcVectorLengths[I4] != 1)
             {
                 return false;
             }
-            if(srcLoadLenghts[I1] % srcVectorLengths[I1] != 0 ||
-               srcLoadLenghts[I2] % srcVectorLengths[I2] != 0)
+            if(srcLoadLenghts[I2] % srcVectorLengths[I2] != 0 ||
+               srcLoadLenghts[I3] % srcVectorLengths[I3] != 0)
             {
                 return false;
             }

             const index_t C = arg.Conv_K_;
-            if(C % (srcVectorLengths[I1] * srcVectorLengths[I2]) != 0)
+            if(C % (srcVectorLengths[I2] * srcVectorLengths[I3]) != 0)
             {
                 return false;
             }
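Both validity checks above change for the same reason: with the batch dim occupying index 0 of the new 5-d descriptors, every former dim index shifts up by one, so the semantics of each test are unchanged. As a reading aid:

Old 4-d indices: I0 = K0, I1 = M0 (or N0), I2 = M1 (or N1), I3 = K1
New 5-d indices: I0 = B, I1 = K0, I2 = M0 (or N0), I3 = M1 (or N1), I4 = K1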
...