Commit d4efc6a7 authored by Po-Yen, Chen
Browse files

Create descriptors on device side

parent e090e72a
...@@ -21,47 +21,6 @@ namespace tensor_operation { ...@@ -21,47 +21,6 @@ namespace tensor_operation {
namespace device { namespace device {
namespace detail { namespace detail {
// Kernel entry point for the N-d backward-data convolution GEMM.
//
// This variant receives the A/B/C grid descriptors by value as kernel
// arguments and forwards them, together with the three grid pointers, a
// block-shared scratch buffer and NumKBlockLoop, to
// GridwiseGemm::Run<HasMainKBlockLoop>.
//
// The real body is compiled for the host pass and for the gfx908 / gfx90a /
// gfx940 device passes. For every other device target the kernel is an empty
// shell whose arguments are only consumed through `ignore`.
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDesc_M_N,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_convnd_bwd_data_nwc_kxc_nwk_xdl(const FloatAB* __restrict__ p_a_grid,
                                               const FloatAB* __restrict__ p_b_grid,
                                               FloatC* __restrict__ p_c_grid,
                                               const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
                                               const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
                                               const CGridDesc_M_N c_grid_desc_m_n,
                                               index_t NumKBlockLoop)
{
#if(defined(__HIP_DEVICE_COMPILE__) && !defined(__gfx908__) && !defined(__gfx90a__) && \
    !defined(__gfx940__))
    // Device pass for a target outside the supported list: nothing to run,
    // just mark every argument as used.
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = a_grid_desc_k0_m_k1;
    ignore = b_grid_desc_k0_n_k1;
    ignore = c_grid_desc_m_n;
    ignore = NumKBlockLoop;
#else
    // Host pass, or device pass for gfx908 / gfx90a / gfx940.
    // The scratch buffer size is dictated by the gridwise GEMM itself.
    constexpr auto shared_bytes = GridwiseGemm::GetSharedMemoryNumberOfByte();
    __shared__ char p_shared[shared_bytes];

    GridwiseGemm::template Run<HasMainKBlockLoop>(
        p_a_grid, p_b_grid, p_c_grid, p_shared,
        a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n,
        NumKBlockLoop);
#endif // defined(__HIP_DEVICE_COMPILE__) && !gfx908 && !gfx90a && !gfx940
}
template <size_t Dim> template <size_t Dim>
struct KernelArgument struct KernelArgument
...@@ -115,6 +74,48 @@ struct KernelArgument ...@@ -115,6 +74,48 @@ struct KernelArgument
}; };
} // namespace detail } // namespace detail
// Kernel entry point for the N-d backward-data convolution GEMM, with the grid
// descriptors created on the device side.
//
// Instead of receiving ready-made A/B/C grid descriptors, this variant takes
// the raw convolution parameters packed in `karg` and calls
// DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDim>(karg) inside
// the kernel. The three entries of the result (Number<0>, Number<1>,
// Number<2>) are forwarded, in that order, as the descriptor arguments of
// GridwiseGemm::Run<HasMainKBlockLoop>.
//
// Template parameters:
//   NDim              - selects the MakeABCGridDescriptor overload and the
//                       KernelArgument specialization
//   DeviceOp          - device operation type providing the descriptor factory
//   GridwiseGemm      - gridwise GEMM providing Run and the shared-memory size
//   FloatAB / FloatC  - element types behind the A/B and C grid pointers
//   HasMainKBlockLoop - forwarded to GridwiseGemm::Run as a template argument
//
// The real body is compiled for the host pass and for the gfx908 / gfx90a /
// gfx940 device passes. For every other device target the kernel is an empty
// shell whose arguments are only consumed through `ignore`.
template <index_t NDim,
          typename DeviceOp,
          typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_convnd_bwd_data_nwc_kxc_nwk_xdl(const FloatAB* __restrict__ p_a_grid,
                                               const FloatAB* __restrict__ p_b_grid,
                                               FloatC* __restrict__ p_c_grid,
                                               detail::KernelArgument<NDim> karg,
                                               index_t NumKBlockLoop)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    // Build the A/B/C grid descriptors in the kernel from the packed
    // convolution parameters rather than receiving them as kernel arguments.
    const auto descs =
        DeviceOp::template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDim>(karg);

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
                                                  p_shared,
                                                  descs[Number<0>{}],
                                                  descs[Number<1>{}],
                                                  descs[Number<2>{}],
                                                  NumKBlockLoop);
#else
    // Device pass for a target outside the supported list: only mark the
    // arguments as used. This branch must name the parameters that actually
    // exist in this signature; the descriptors are no longer kernel arguments,
    // so `karg` is consumed in their place.
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = karg;
    ignore = NumKBlockLoop;
#endif // !defined(__HIP_DEVICE_COMPILE__) || gfx908 || gfx90a || gfx940
}
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InDataType, typename InDataType,
...@@ -201,7 +202,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -201,7 +202,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
static constexpr auto GemmK1Number = K1Number; static constexpr auto GemmK1Number = K1Number;
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg) static __host__ __device__ auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
{ {
using namespace ck; using namespace ck;
...@@ -390,7 +392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -390,7 +392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg) static __host__ __device__ auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
{ {
using namespace ck; using namespace ck;
...@@ -650,7 +653,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -650,7 +653,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg) static __host__ __device__ auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
{ {
using namespace ck; using namespace ck;
...@@ -1127,8 +1131,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1127,8 +1131,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
continue; continue;
} }
const auto descs = karg_container_.push_back(
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
detail::KernelArgument<NDimSpatial>(Conv_N_, detail::KernelArgument<NDimSpatial>(Conv_N_,
Conv_K_, Conv_K_,
Conv_C_, Conv_C_,
...@@ -1141,7 +1144,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1141,7 +1144,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_right_pads_, input_right_pads_,
{i_xtilde}, {i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_))); GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs);
} }
} }
...@@ -1174,10 +1176,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1174,10 +1176,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
continue; continue;
} }
const auto descs = karg_container_.push_back(
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( detail::KernelArgument<NDimSpatial>(Conv_N_,
detail::KernelArgument<NDimSpatial>(
Conv_N_,
Conv_K_, Conv_K_,
Conv_C_, Conv_C_,
input_spatial_lengths_, input_spatial_lengths_,
...@@ -1189,7 +1189,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1189,7 +1189,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_right_pads_, input_right_pads_,
{i_ytilde, i_xtilde}, {i_ytilde, i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_))); GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs);
} }
} }
} }
...@@ -1231,9 +1230,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1231,9 +1230,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
continue; continue;
} }
const auto descs = karg_container_.push_back(detail::KernelArgument<NDimSpatial>(
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
detail::KernelArgument<NDimSpatial>(
Conv_N_, Conv_N_,
Conv_K_, Conv_K_,
Conv_C_, Conv_C_,
...@@ -1246,7 +1243,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1246,7 +1243,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_right_pads_, input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde}, {i_ztilde, i_ytilde, i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_))); GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs);
} }
} }
} }
...@@ -1255,7 +1251,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1255,7 +1251,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
std::vector<ABCGridDescs> grid_desc_container_; std::vector<detail::KernelArgument<NDimSpatial>> karg_container_;
index_t M01_; index_t M01_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
index_t Conv_N_; index_t Conv_N_;
...@@ -1279,35 +1275,28 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1279,35 +1275,28 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
float ave_time = 0; float ave_time = 0;
for(size_t i = 0; i < arg.grid_desc_container_.size(); i++) for(size_t i = 0; i < arg.karg_container_.size(); i++)
{ {
auto a_grid_desc_k0_m_k1 = arg.grid_desc_container_[i][I0]; const auto K = arg.Conv_K_;
auto b_grid_desc_k0_n_k1 = arg.grid_desc_container_[i][I1];
auto c_grid_desc_m_n = arg.grid_desc_container_[i][I2];
if(!GridwiseGemm::CheckValidity( const auto descs =
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n)) DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
{ arg.karg_container_[i]);
throw std::runtime_error( const auto c_grid_desc_m_n = descs[I2];
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
const auto K =
a_grid_desc_k0_m_k1.GetLength(I0) * a_grid_desc_k0_m_k1.GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = detail::kernel_convnd_bwd_data_nwc_kxc_nwk_xdl< const auto kernel =
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl<NDimSpatial,
DeviceOp,
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B
// datatype
CDataType, CDataType,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
true>; true>;
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
...@@ -1318,21 +1307,18 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1318,21 +1307,18 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
a_grid_desc_k0_m_k1, arg.karg_container_[i],
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
GridwiseGemm::CalculateNumKBlockLoop(K)); GridwiseGemm::CalculateNumKBlockLoop(K));
} }
else else
{ {
const auto kernel = detail::kernel_convnd_bwd_data_nwc_kxc_nwk_xdl< const auto kernel =
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl<NDimSpatial,
DeviceOp,
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B
// datatype
CDataType, CDataType,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
false>; false>;
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
...@@ -1343,9 +1329,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1343,9 +1329,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
a_grid_desc_k0_m_k1, arg.karg_container_[i],
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
GridwiseGemm::CalculateNumKBlockLoop(K)); GridwiseGemm::CalculateNumKBlockLoop(K));
} }
} }
...@@ -1395,16 +1379,16 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1395,16 +1379,16 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
return false; return false;
} }
// Gridwise GEMM size // // Gridwise GEMM size
for(std::size_t i = 0; i < arg.grid_desc_container_.size(); i++) // for(std::size_t i = 0; i < arg.grid_desc_container_.size(); i++)
{ // {
if(!GridwiseGemm::CheckValidity(arg.grid_desc_container_[i][I0], // if(!GridwiseGemm::CheckValidity(arg.grid_desc_container_[i][I0],
arg.grid_desc_container_[i][I1], // arg.grid_desc_container_[i][I1],
arg.grid_desc_container_[i][I2])) // arg.grid_desc_container_[i][I2]))
{ // {
return false; // return false;
} // }
} // }
return true; return true;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment