Commit 3f976dd0 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Update batch handling

parent b9f23971
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp" #include <algorithm>
#include <iostream> #include <iostream>
#include <numeric> #include <iterator>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -37,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault = ...@@ -37,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault =
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceConvNdBwdWeightNwcKxcNwk_Dl< ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl<
NDimSpatial, // NDimSpatial NDimSpatial, // NDimSpatial
InDataType, // InDataType InDataType, // InDataType
WeiDataType, // WeiDataType WeiDataType, // WeiDataType
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <numeric>
#include <sstream> #include <sstream>
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
...@@ -20,6 +21,83 @@ namespace ck { ...@@ -20,6 +21,83 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace {
struct ComputePtrOffsetOfStridedBatch
{
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
} // namespace
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename DefaultBlock2CTileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_dlops_bwd_weight(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const index_t batch_count,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const DefaultBlock2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc_kbatch_k0_m0_m1_k1,
b_grid_desc_kbatch_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
...@@ -56,7 +134,7 @@ template <ck::index_t NDimSpatial, ...@@ -56,7 +134,7 @@ template <ck::index_t NDimSpatial,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector> index_t CThreadTransferDstScalarPerVector>
struct DeviceConvNdBwdWeightNwcKxcNwk_Dl struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
: public DeviceGroupedConvBwdWeight< : public DeviceGroupedConvBwdWeight<
NDimSpatial, NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
...@@ -78,7 +156,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -78,7 +156,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation> OutElementwiseOperation>
{ {
using DeviceOp = DeviceConvNdBwdWeightNwcKxcNwk_Dl; using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl;
using ADataType = OutDataType; using ADataType = OutDataType;
using BDataType = InDataType; using BDataType = InDataType;
...@@ -116,8 +194,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -116,8 +194,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
static constexpr auto BBlockLdsN1Padding = 4; static constexpr auto BBlockLdsN1Padding = 4;
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
...@@ -268,8 +346,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -268,8 +346,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
} // function end } // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
...@@ -436,8 +514,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -436,8 +514,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
} // function end } // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
...@@ -727,6 +805,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -727,6 +805,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
a_grid_desc_kbatch_k0_m_k1_{}, a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{}, b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{}, c_grid_desc_m_n_{},
block_2_ctile_map_{},
compute_ptr_offset_of_batch_{},
a_element_op_{out_element_op}, a_element_op_{out_element_op},
b_element_op_{wei_element_op}, b_element_op_{wei_element_op},
c_element_op_{in_element_op}, c_element_op_{in_element_op},
...@@ -761,10 +841,33 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -761,10 +841,33 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2]; c_grid_desc_m_n_ = descs[I2];
a_grid_desc_kbatch_k0_m0_m1_k1_ = GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_); a_grid_desc_kbatch_k0_m0_m1_k1_ =
b_grid_desc_kbatch_k0_n0_n1_k1_ = GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_); GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ = GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); b_grid_desc_kbatch_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ =
N * K *
std::accumulate(begin(output_spatial_lengths),
end(output_spatial_lengths),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideB_ =
N * C *
std::accumulate(begin(input_spatial_lengths),
end(input_spatial_lengths),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ =
K * C *
std::accumulate(begin(filter_spatial_lengths),
end(filter_spatial_lengths),
index_t{1},
std::multiplies<>{});
} }
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
...@@ -781,6 +884,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -781,6 +884,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
DefaultBlock2CTileMap block_2_ctile_map_; DefaultBlock2CTileMap block_2_ctile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
// element-wise op // element-wise op
OutElementwiseOperation a_element_op_; OutElementwiseOperation a_element_op_;
WeiElementwiseOperation b_element_op_; WeiElementwiseOperation b_element_op_;
...@@ -813,20 +919,16 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -813,20 +919,16 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
<< std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
} }
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
...@@ -850,7 +952,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -850,7 +952,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop; constexpr bool has_double_loop = has_double_tail_k_block_loop;
const auto kernel = kernel_gemm_dl_v1r3< const auto kernel = kernel_batched_gemm_dlops_bwd_weight<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
...@@ -858,6 +960,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -858,6 +960,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
remove_reference_t<DeviceOp::BGridDesc_K0_N0_N1_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>, remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::DefaultBlock2CTileMap>, remove_reference_t<DeviceOp::DefaultBlock2CTileMap>,
ComputePtrOffsetOfStridedBatch,
has_main_loop, has_main_loop,
has_double_loop>; has_double_loop>;
...@@ -869,10 +972,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -869,10 +972,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.Conv_G_,
arg.a_grid_desc_kbatch_k0_m0_m1_k1_, arg.a_grid_desc_kbatch_k0_m0_m1_k1_,
arg.b_grid_desc_kbatch_k0_n0_n1_k1_, arg.b_grid_desc_kbatch_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
}; };
const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1); const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1);
...@@ -882,7 +987,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -882,7 +987,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
return launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{}); return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
...@@ -987,9 +1093,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -987,9 +1093,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, return GridwiseGemm::CheckValidity(
arg.b_grid_desc_kbatch_k0_n_k1_, arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_);
arg.c_grid_desc_m_n_);
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -1088,7 +1193,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -1088,7 +1193,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceConvNdBwdWeightNwcKxcNwk_Dl" str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment