Commit 0be1cf14 authored by Chao Liu's avatar Chao Liu
Browse files

update conv bwd weight

parent b054669b
......@@ -15,6 +15,19 @@ enum struct ConvolutionBackwardWeightSpecialization
OddC,
};
inline std::string
getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization& s)
{
switch(s)
{
case ConvolutionBackwardWeightSpecialization::Default: return "Default";
case ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0:
return "Filter1x1Stride1Pad0";
case ConvolutionBackwardWeightSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
case ConvolutionBackwardWeightSpecialization::OddC: return "OddC";
default: return "Unrecognized specialization!";
}
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION
#define CONVOLUTION_FORWARD_SPECIALIZATION
#pragma once
#include <string>
......@@ -33,4 +32,3 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
......@@ -57,7 +57,14 @@ template <typename InDataType,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvBwdWeight<InElementwiseOperation,
: public DeviceConvBwdWeight<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
......
......@@ -55,7 +55,14 @@ template <typename InDataType,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvBwdData<InElementwiseOperation,
: public DeviceConvBwdData<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
......
......@@ -4,16 +4,21 @@
#pragma once
#include <vector>
#include <iostream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InElementwiseOperation,
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConvBwdData : public BaseOperator
......@@ -39,12 +44,6 @@ struct DeviceConvBwdData : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvBwdDataPtr = std::unique_ptr<
DeviceConvBwdData<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -4,7 +4,6 @@
#pragma once
#include <vector>
#include <iostream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
......@@ -12,7 +11,14 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename InElementwiseOperation,
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConvBwdWeight : public BaseOperator
......@@ -39,12 +45,6 @@ struct DeviceConvBwdWeight : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvBwdWeightPtr = std::unique_ptr<
DeviceConvBwdWeight<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -3,7 +3,6 @@
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
......@@ -12,7 +11,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <ck::index_t NumDimSpatial,
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
......
......@@ -21,7 +21,8 @@ namespace tensor_operation {
namespace device {
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <typename InDataType,
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
......@@ -29,7 +30,6 @@ template <typename InDataType,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
ck::index_t NumDimSpatial,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
......@@ -55,12 +55,29 @@ template <typename InDataType,
bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
: public DeviceConvBwdData<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
: public DeviceConvBwdData<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp = DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K;
using DeviceOp = DeviceConvNdBwdDataNwcKxcNwk_Xdl;
using ADataType = OutDataType;
using BDataType = WeiDataType;
......@@ -950,7 +967,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
{0, 0, 0});
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
......@@ -1037,7 +1054,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
CreateABCDesc<NumDimSpatial>();
CreateABCDesc<NDimSpatial>();
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
......@@ -1060,7 +1077,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NumDimSpatial>(
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
......@@ -1118,7 +1135,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NumDimSpatial>(
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
......@@ -1186,18 +1203,18 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
NumDimSpatial>(Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde});
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
......@@ -1398,7 +1415,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NumDimSpatial; i++)
for(int i = 0; i < NDimSpatial; i++)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
......@@ -1528,7 +1545,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
auto str = std::stringstream();
// clang-format off
str << "DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K"
str << "DeviceConvNdBwdDataNwcKxcNwk_Xdl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
......@@ -22,7 +22,8 @@ namespace tensor_operation {
namespace device {
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <typename InDataType,
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
......@@ -30,7 +31,6 @@ template <typename InDataType,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization,
ck::index_t NumDimSpatial,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
......@@ -58,13 +58,29 @@ template <typename InDataType,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvBwdWeight<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
: public DeviceConvBwdWeight<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp =
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using DeviceOp = DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle;
using ADataType = OutDataType;
using BDataType = InDataType;
......@@ -675,125 +691,19 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
}
using TypeConvertFp32ToBf16Functor =
ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));
using GridwiseUEltwise = GridwiseUnaryElementwise_1D<AccDataType,
InDataType,
GridDesc_M0,
TypeConvertFp32ToBf16Functor,
4>;
using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXdl,
NPerXdl,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
ABlockLdsM1PerBlock,
ABlockLdsM0PerBlock,
ABlockLdsM1Padding,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
BBlockLdsN1PerBlock,
BBlockLdsN0PerBlock,
BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXdl,
NPerXdl,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
ABlockLdsM1PerBlock,
ABlockLdsM0PerBlock,
ABlockLdsM1Padding,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
BBlockLdsN1PerBlock,
BBlockLdsN0PerBlock,
BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
using GridwiseGemmAtomicAddFloatBf16Splitk = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
AccDataType,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
......@@ -890,7 +800,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
k_batch_{split_k}
{
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NumDimSpatial>(
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
N,
K,
C,
......@@ -980,268 +890,55 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
float ave_time = 0;
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
const auto run_conv = [&](const auto& kernel) {
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
float elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_and_time_kernel(StreamConfig{nullptr, false},
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
return elapsed_time;
};
// run kernel for bf16 with splitk
const auto run_bf16_splitk = [&](const auto& kernel) {
hipGetErrorString(hipMemset(
arg.p_workspace_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(AccDataType)));
float elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
static_cast<AccDataType*>(arg.p_workspace_),
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
hipGetErrorString(hipMemset(
arg.p_workspace_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(AccDataType)));
launch_and_time_kernel(StreamConfig{nullptr, false},
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
static_cast<AccDataType*>(arg.p_workspace_),
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
return elapsed_time;
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
// kernel for type conversion
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(arg.Conv_K_),
static_cast<std::size_t>(arg.Conv_C_)};
filter_dims.insert(std::end(filter_dims),
std::begin(arg.filter_spatial_lengths_),
std::end(arg.filter_spatial_lengths_));
int tensor_size =
std::accumulate(filter_dims.begin(), filter_dims.end(), 1, std::multiplies<int>{});
const index_t type_convert_grid_size = GridwiseUEltwise::CalculateGridSize(tensor_size);
GridDesc_M0 a_grid_desc_m0_ =
MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
GridDesc_M0 b_grid_desc_m0_ =
MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
if(!GridwiseUEltwise::CheckValidity(a_grid_desc_m0_, b_grid_desc_m0_))
if(has_main_k0_block_loop)
{
throw std::runtime_error("wrong! GridwiseUnaryElementwise_1D has invalid setting");
}
// run kernel for type conversion
void* p_c_grid_tmp_ = static_cast<void*>(arg.p_c_grid_);
InDataType* p_c_grid_tmp_bf16_ = static_cast<InDataType*>(p_c_grid_tmp_);
const auto run_type_convert = [&](const auto& kernel) {
float elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(type_convert_grid_size),
dim3(256),
0,
static_cast<AccDataType*>(arg.p_workspace_),
p_c_grid_tmp_bf16_,
a_grid_desc_m0_,
b_grid_desc_m0_,
TypeConvertFp32ToBf16Functor{});
return elapsed_time;
};
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
has_main_loop>;
return run_conv(kernel);
}
else
{
const auto kernel_type_convert =
kernel_unary_elementwise_1d<GridwiseUEltwise,
AccDataType,
InDataType,
GridDesc_M0,
TypeConvertFp32ToBf16Functor>;
const auto kernel_conv = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAddFloatBf16Splitk,
ADataType, // TODO: distiguish A/B datatype
AccDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
has_main_loop>;
float elapsed_time = 0;
elapsed_time += run_bf16_splitk(kernel_conv);
elapsed_time += run_type_convert(kernel_type_convert);
return elapsed_time;
}
};
if(has_main_k0_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return launch_kernel(integral_constant<bool, true>{});
}
else
{
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
has_main_loop>;
return run_conv(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
has_main_loop>;
return run_conv(kernel);
}
};
if(has_main_k0_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
float Run(const BaseArgument* p_arg,
......@@ -1263,7 +960,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NumDimSpatial; i++)
for(int i = 0; i < NDimSpatial; i++)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
......@@ -1390,74 +1087,18 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
auto str = std::stringstream();
// clang-format off
str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
str << "DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization)
<< ">";
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0){
str << " Filter1x1Stride1Pad0";
}
// clang-format on
return str.str();
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static size_t GetWorkSpaceSize(const Argument& arg)
{
size_t WorkSpaceSize = 0;
if(arg.k_batch_ > 1)
{
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
WorkSpaceSize =
arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * sizeof(float);
}
}
return WorkSpaceSize;
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static size_t GetWorkSpaceSize(const Argument& arg)
{
size_t WorkSpaceSize = 0;
if(arg.k_batch_ > 1)
{
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
arg.filter_spatial_lengths_[1] * sizeof(float);
}
}
return WorkSpaceSize;
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static size_t GetWorkSpaceSize(const Argument& arg)
{
size_t WorkSpaceSize = 0;
if(arg.k_batch_ > 1)
{
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
arg.filter_spatial_lengths_[1] * arg.filter_spatial_lengths_[2] *
sizeof(float);
}
}
return WorkSpaceSize;
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override final
{
return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast<const Argument*>(p_arg));
}
};
} // namespace device
......
......@@ -39,7 +39,7 @@ namespace device {
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template <ck::index_t NumDimSpatial,
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
......@@ -74,16 +74,16 @@ template <ck::index_t NumDimSpatial,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceConvNdFwdNwcKxcNwk_Xdl
: public DeviceConvFwd<NumDimSpatial,
ck::tuple_element_t<NumDimSpatial - 1,
: public DeviceConvFwd<NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NumDimSpatial - 1,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NumDimSpatial - 1,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
......@@ -94,27 +94,6 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
WeiElementwiseOperation,
OutElementwiseOperation>
{
using Base =
DeviceConvFwd<NumDimSpatial,
ck::tuple_element_t<NumDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NumDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NumDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>;
using DeviceOp = DeviceConvNdFwdNwcKxcNwk_Xdl;
using ADataType = InDataType;
......@@ -124,8 +103,6 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
// TODO make A/B datatype different
using ABDataType = InDataType;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -599,18 +576,18 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
// C = A^T*B
// A:
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
GetInputTensorDescriptor<NumDimSpatial>(N,
C,
GemmMRaw,
GemmK,
GemmMPad,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
GetInputTensorDescriptor<NDimSpatial>(N,
C,
GemmMRaw,
GemmK,
GemmMPad,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// B:
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmN, GemmK);
// C:
......@@ -642,7 +619,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
......@@ -934,7 +911,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(ck::index_t i = 0; i < NumDimSpatial; ++i)
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
......@@ -947,7 +924,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(ck::index_t i = 0; i < NumDimSpatial; ++i)
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.input_left_pads_[i] == 0 &&
arg.input_right_pads_[i] == 0))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment