Commit d4efc6a7 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Create descriptors on device side

parent e090e72a
......@@ -21,47 +21,6 @@ namespace tensor_operation {
namespace device {
namespace detail {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M_N c_grid_desc_m_n,
index_t NumKBlockLoop)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
NumKBlockLoop);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m_n;
ignore = NumKBlockLoop;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <size_t Dim>
struct KernelArgument
......@@ -115,6 +74,48 @@ struct KernelArgument
};
} // namespace detail
template <index_t NDim,
typename DeviceOp,
typename GridwiseGemm,
typename FloatAB,
typename FloatC,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
detail::KernelArgument<NDim> karg,
index_t NumKBlockLoop)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const auto descs =
DeviceOp::template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDim>(karg);
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
descs[Number<0>{}],
descs[Number<1>{}],
descs[Number<2>{}],
NumKBlockLoop);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m_n;
ignore = NumKBlockLoop;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <ck::index_t NDimSpatial,
typename InDataType,
......@@ -201,7 +202,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
static constexpr auto GemmK1Number = K1Number;
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
static __host__ __device__ auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
{
using namespace ck;
......@@ -390,7 +392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
static __host__ __device__ auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
{
using namespace ck;
......@@ -650,7 +653,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
static __host__ __device__ auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
{
using namespace ck;
......@@ -1127,21 +1131,19 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
continue;
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
detail::KernelArgument<NDimSpatial>(Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs);
karg_container_.push_back(
detail::KernelArgument<NDimSpatial>(Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
}
}
......@@ -1174,22 +1176,19 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
continue;
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
detail::KernelArgument<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ytilde, i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs);
karg_container_.push_back(
detail::KernelArgument<NDimSpatial>(Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ytilde, i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
}
}
}
......@@ -1231,22 +1230,19 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
continue;
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
detail::KernelArgument<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs);
karg_container_.push_back(detail::KernelArgument<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
}
}
}
......@@ -1255,7 +1251,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
std::vector<ABCGridDescs> grid_desc_container_;
std::vector<detail::KernelArgument<NDimSpatial>> karg_container_;
index_t M01_;
// for checking IsSupportedArgument()
index_t Conv_N_;
......@@ -1279,36 +1275,29 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(size_t i = 0; i < arg.grid_desc_container_.size(); i++)
for(size_t i = 0; i < arg.karg_container_.size(); i++)
{
auto a_grid_desc_k0_m_k1 = arg.grid_desc_container_[i][I0];
auto b_grid_desc_k0_n_k1 = arg.grid_desc_container_[i][I1];
auto c_grid_desc_m_n = arg.grid_desc_container_[i][I2];
const auto K = arg.Conv_K_;
if(!GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
arg.karg_container_[i]);
const auto c_grid_desc_m_n = descs[I2];
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
const auto K =
a_grid_desc_k0_m_k1.GetLength(I0) * a_grid_desc_k0_m_k1.GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = detail::kernel_convnd_bwd_data_nwc_kxc_nwk_xdl<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
true>;
const auto kernel =
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl<NDimSpatial,
DeviceOp,
GridwiseGemm,
ADataType, // TODO: distiguish A/B
// datatype
CDataType,
true>;
ave_time += launch_and_time_kernel(stream_config,
kernel,
......@@ -1318,22 +1307,19 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
arg.karg_container_[i],
GridwiseGemm::CalculateNumKBlockLoop(K));
}
else
{
const auto kernel = detail::kernel_convnd_bwd_data_nwc_kxc_nwk_xdl<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
false>;
const auto kernel =
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl<NDimSpatial,
DeviceOp,
GridwiseGemm,
ADataType, // TODO: distiguish A/B
// datatype
CDataType,
false>;
ave_time += launch_and_time_kernel(stream_config,
kernel,
......@@ -1343,9 +1329,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
arg.karg_container_[i],
GridwiseGemm::CalculateNumKBlockLoop(K));
}
}
......@@ -1395,16 +1379,16 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
return false;
}
// Gridwise GEMM size
for(std::size_t i = 0; i < arg.grid_desc_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.grid_desc_container_[i][I0],
arg.grid_desc_container_[i][I1],
arg.grid_desc_container_[i][I2]))
{
return false;
}
}
// // Gridwise GEMM size
// for(std::size_t i = 0; i < arg.grid_desc_container_.size(); i++)
// {
// if(!GridwiseGemm::CheckValidity(arg.grid_desc_container_[i][I0],
// arg.grid_desc_container_[i][I1],
// arg.grid_desc_container_[i][I2]))
// {
// return false;
// }
// }
return true;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment