Commit 3f976dd0 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Update batch handling

parent b9f23971
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp" #include <algorithm>
#include <iostream> #include <iostream>
#include <numeric> #include <iterator>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -37,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault = ...@@ -37,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault =
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceConvNdBwdWeightNwcKxcNwk_Dl< ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl<
NDimSpatial, // NDimSpatial NDimSpatial, // NDimSpatial
InDataType, // InDataType InDataType, // InDataType
WeiDataType, // WeiDataType WeiDataType, // WeiDataType
...@@ -142,7 +141,7 @@ int run_conv_bwd_weight(bool do_verification, ...@@ -142,7 +141,7 @@ int run_conv_bwd_weight(bool do_verification,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}; std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{}; std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
...@@ -295,16 +294,16 @@ int main(int argc, char* argv[]) ...@@ -295,16 +294,16 @@ int main(int argc, char* argv[])
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
DeviceConvBwdWeightInstance<1>>(do_verification, DeviceConvBwdWeightInstance<1>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
conv_param, conv_param,
in_g_n_c_wis_desc, in_g_n_c_wis_desc,
wei_g_k_c_xs_desc, wei_g_k_c_xs_desc,
out_g_n_k_wos_desc, out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
split_k); split_k);
} }
else if(conv_param.num_dim_spatial_ == 2) else if(conv_param.num_dim_spatial_ == 2)
{ {
...@@ -332,16 +331,16 @@ int main(int argc, char* argv[]) ...@@ -332,16 +331,16 @@ int main(int argc, char* argv[])
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
DeviceConvBwdWeightInstance<2>>(do_verification, DeviceConvBwdWeightInstance<2>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
conv_param, conv_param,
in_g_n_c_wis_desc, in_g_n_c_wis_desc,
wei_g_k_c_xs_desc, wei_g_k_c_xs_desc,
out_g_n_k_wos_desc, out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
split_k); split_k);
} }
else if(conv_param.num_dim_spatial_ == 3) else if(conv_param.num_dim_spatial_ == 3)
{ {
...@@ -369,16 +368,16 @@ int main(int argc, char* argv[]) ...@@ -369,16 +368,16 @@ int main(int argc, char* argv[])
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
DeviceConvBwdWeightInstance<3>>(do_verification, DeviceConvBwdWeightInstance<3>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
conv_param, conv_param,
in_g_n_c_wis_desc, in_g_n_c_wis_desc,
wei_g_k_c_xs_desc, wei_g_k_c_xs_desc,
out_g_n_k_wos_desc, out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
split_k); split_k);
} }
return 0; return 0;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <numeric>
#include <sstream> #include <sstream>
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
...@@ -20,6 +21,83 @@ namespace ck { ...@@ -20,6 +21,83 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace {
struct ComputePtrOffsetOfStridedBatch
{
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
} // namespace
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename DefaultBlock2CTileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_dlops_bwd_weight(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const index_t batch_count,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const DefaultBlock2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc_kbatch_k0_m0_m1_k1,
b_grid_desc_kbatch_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
...@@ -56,7 +134,7 @@ template <ck::index_t NDimSpatial, ...@@ -56,7 +134,7 @@ template <ck::index_t NDimSpatial,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector> index_t CThreadTransferDstScalarPerVector>
struct DeviceConvNdBwdWeightNwcKxcNwk_Dl struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
: public DeviceGroupedConvBwdWeight< : public DeviceGroupedConvBwdWeight<
NDimSpatial, NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
...@@ -78,7 +156,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -78,7 +156,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation> OutElementwiseOperation>
{ {
using DeviceOp = DeviceConvNdBwdWeightNwcKxcNwk_Dl; using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl;
using ADataType = OutDataType; using ADataType = OutDataType;
using BDataType = InDataType; using BDataType = InDataType;
...@@ -116,18 +194,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -116,18 +194,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
static constexpr auto BBlockLdsN1Padding = 4; static constexpr auto BBlockLdsN1Padding = 4;
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k) ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -255,7 +333,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -255,7 +333,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)), make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc =
...@@ -263,23 +341,23 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -263,23 +341,23 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmm_gemmn_grid_desc);
} }
} // function end } // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k) ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -436,18 +514,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -436,18 +514,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
} // function end } // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k) ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -727,6 +805,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -727,6 +805,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
a_grid_desc_kbatch_k0_m_k1_{}, a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{}, b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{}, c_grid_desc_m_n_{},
block_2_ctile_map_{},
compute_ptr_offset_of_batch_{},
a_element_op_{out_element_op}, a_element_op_{out_element_op},
b_element_op_{wei_element_op}, b_element_op_{wei_element_op},
c_element_op_{in_element_op}, c_element_op_{in_element_op},
...@@ -759,12 +839,35 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -759,12 +839,35 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2]; c_grid_desc_m_n_ = descs[I2];
a_grid_desc_kbatch_k0_m0_m1_k1_ = GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_); a_grid_desc_kbatch_k0_m0_m1_k1_ =
b_grid_desc_kbatch_k0_n0_n1_k1_ = GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_); GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ = GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); b_grid_desc_kbatch_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ =
N * K *
std::accumulate(begin(output_spatial_lengths),
end(output_spatial_lengths),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideB_ =
N * C *
std::accumulate(begin(input_spatial_lengths),
end(input_spatial_lengths),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ =
K * C *
std::accumulate(begin(filter_spatial_lengths),
end(filter_spatial_lengths),
index_t{1},
std::multiplies<>{});
} }
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
...@@ -781,11 +884,14 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -781,11 +884,14 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
DefaultBlock2CTileMap block_2_ctile_map_; DefaultBlock2CTileMap block_2_ctile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
// element-wise op // element-wise op
OutElementwiseOperation a_element_op_; OutElementwiseOperation a_element_op_;
WeiElementwiseOperation b_element_op_; WeiElementwiseOperation b_element_op_;
InElementwiseOperation c_element_op_; InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
index_t Conv_G_; index_t Conv_G_;
index_t Conv_N_; index_t Conv_N_;
...@@ -813,25 +919,21 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -813,25 +919,21 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
<< std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
} }
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
ShowInfo(arg); ShowInfo(arg);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
...@@ -845,59 +947,63 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -845,59 +947,63 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_; arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;
auto launch_kernel = [&](auto has_main_k_block_loop, auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) { auto has_double_tail_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop; constexpr bool has_double_loop = has_double_tail_k_block_loop;
const auto kernel = kernel_gemm_dl_v1r3< const auto kernel = kernel_batched_gemm_dlops_bwd_weight<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M0_M1_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M0_M1_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N0_N1_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>, remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::DefaultBlock2CTileMap>, remove_reference_t<DeviceOp::DefaultBlock2CTileMap>,
has_main_loop, ComputePtrOffsetOfStridedBatch,
has_double_loop>; has_main_loop,
has_double_loop>;
return launch_and_time_kernel(stream_config,
kernel, return launch_and_time_kernel(stream_config,
dim3(grid_size), kernel,
dim3(BlockSize), dim3(grid_size),
0, dim3(BlockSize),
arg.p_a_grid_, 0,
arg.p_b_grid_, arg.p_a_grid_,
arg.p_c_grid_, arg.p_b_grid_,
arg.a_grid_desc_kbatch_k0_m0_m1_k1_, arg.p_c_grid_,
arg.b_grid_desc_kbatch_k0_n0_n1_k1_, arg.Conv_G_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.a_grid_desc_kbatch_k0_m0_m1_k1_,
arg.block_2_ctile_map_); arg.b_grid_desc_kbatch_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
}; };
const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1); const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop = const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
return launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{}); return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
return launch_kernel(integral_constant<bool, true>{}, return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
return launch_kernel(integral_constant<bool, false>{}, return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
} }
else else
{ {
return launch_kernel(integral_constant<bool, false>{}, return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
} }
...@@ -987,9 +1093,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -987,9 +1093,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, return GridwiseGemm::CheckValidity(
arg.b_grid_desc_kbatch_k0_n_k1_, arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_);
arg.c_grid_desc_m_n_);
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -1088,7 +1193,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -1088,7 +1193,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceConvNdBwdWeightNwcKxcNwk_Dl" str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment