Commit 94e03cf5 authored by Po-Yen, Chen
Browse files

Remove unnecessary kernel arguments

parent d23a7617
...@@ -20,6 +20,50 @@ namespace ck { ...@@ -20,6 +20,50 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace detail {
// Device kernel entry point that forwards its arguments to GridwiseGemm::Run.
//
// The kernel itself contains no GEMM logic: it declares a shared-memory scratch
// buffer sized by GridwiseGemm::GetSharedMemoryNumberOfByte() and hands it,
// together with the grid pointers, the three grid descriptors and the K-loop
// count, to GridwiseGemm::Run<HasMainKBlockLoop>.
//
// Template parameters:
//   GridwiseGemm       - type providing the static GetSharedMemoryNumberOfByte()
//                        and Run<bool>() members used below.
//   FloatAB            - element type of p_a_grid and p_b_grid.
//   FloatC             - element type of p_c_grid.
//   AGridDesc_K0_M_K1,
//   BGridDesc_K0_N_K1,
//   CGridDesc_M_N      - descriptor types; instances are received by value.
//   HasMainKBlockLoop  - compile-time flag forwarded to Run<>.
//
// Parameters:
//   p_a_grid, p_b_grid - read-only input pointers (declared __restrict__).
//   p_c_grid           - writable output pointer (declared __restrict__).
//   a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n
//                      - descriptors forwarded unchanged to Run.
//   NumKBlockLoop      - forwarded unchanged to Run as its last argument.
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
bool HasMainKBlockLoop>
__global__ void
// Launch bounds are attached only when the build enables CK_USE_LAUNCH_BOUNDS.
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M_N c_grid_desc_m_n,
index_t NumKBlockLoop)
{
// The real body is compiled when __HIP_DEVICE_COMPILE__ is not defined
// (presumably the host compilation pass - TODO confirm) or when the device
// target is gfx908, gfx90a or gfx940. For any other device target the kernel
// compiles to the empty fallback in the #else branch.
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
// Shared-memory scratch buffer; its size is dictated by GridwiseGemm.
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
NumKBlockLoop);
#else
// Fallback: every argument is assigned to `ignore` and nothing else happens.
// NOTE(review): `ignore` is presumably the project's sink object used to
// silence unused-parameter warnings - confirm against its definition.
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m_n;
ignore = NumKBlockLoop;
#endif // end of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
}
} // namespace detail
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InDataType, typename InDataType,
...@@ -967,12 +1011,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -967,12 +1011,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
{0, 0, 0}); {0, 0, 0});
} }
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize, BlockSize,
...@@ -980,9 +1018,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -980,9 +1018,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -1014,6 +1049,15 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1014,6 +1049,15 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
7, // CThreadTransferSrcDstVectorDim, 7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector>;
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(std::declval<ABCGridDescs>()[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(std::declval<ABCGridDescs>()[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(std::declval<ABCGridDescs>()[I2])>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(std::declval<CGridDesc_M_N>()));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -1030,19 +1074,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1030,19 +1074,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
ck::index_t M01, ck::index_t M01)
ck::index_t N01,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_out_grid}, : p_a_grid_{p_out_grid},
p_b_grid_{p_wei_grid}, p_b_grid_{p_wei_grid},
p_c_grid_{p_in_grid}, p_c_grid_{p_in_grid},
M01_{M01}, M01_{M01},
N01_{N01},
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
Conv_N_{N}, Conv_N_{N},
Conv_K_{K}, Conv_K_{K},
Conv_C_{C}, Conv_C_{C},
...@@ -1092,17 +1128,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1092,17 +1128,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]); c_grid_desc_m_n_container_.push_back(descs[I2]);
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
} }
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
...@@ -1150,18 +1175,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1150,18 +1175,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]); c_grid_desc_m_n_container_.push_back(descs[I2]);
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
} }
} }
} }
...@@ -1218,19 +1231,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1218,19 +1231,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]); c_grid_desc_m_n_container_.push_back(descs[I2]);
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
descs[I2]));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
} }
} }
} }
...@@ -1242,14 +1242,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1242,14 +1242,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_; std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_; std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_; std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_;
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
index_t M01_; index_t M01_;
index_t N01_;
OutElementwiseOperation a_element_op_;
WeiElementwiseOperation b_element_op_;
InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
index_t Conv_N_; index_t Conv_N_;
index_t Conv_K_; index_t Conv_K_;
...@@ -1315,39 +1308,34 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1315,39 +1308,34 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i], arg.c_grid_desc_m_n_container_[i]))
arg.block_2_ctile_map_container_[i]))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
} }
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( index_t gdx, gdy, gdz;
arg.c_grid_desc_m_n_container_[i]); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
arg.c_grid_desc_m_n_container_[i].GetLength(I0),
arg.c_grid_desc_m_n_container_[i].GetLength(I1));
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel = detail::kernel_convnd_bwd_data_nwc_kxc_nwk_xdl<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, AGridDesc_K0_M_K1,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, BGridDesc_K0_N_K1,
remove_reference_t< CGridDesc_M_N,
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation,
WeiElementwiseOperation,
InElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(stream_config,
stream_config,
kernel, kernel,
dim3(grid_size), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
...@@ -1355,32 +1343,24 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1355,32 +1343,24 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i], arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], arg.c_grid_desc_m_n_container_[i],
arg.a_element_op_, GridwiseGemm::CalculateNumKBlockLoop(K));
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_container_[i]);
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel = detail::kernel_convnd_bwd_data_nwc_kxc_nwk_xdl<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, AGridDesc_K0_M_K1,
remove_reference_t< BGridDesc_K0_N_K1,
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, CGridDesc_M_N,
OutElementwiseOperation,
WeiElementwiseOperation,
InElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(stream_config,
stream_config,
kernel, kernel,
dim3(grid_size), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
...@@ -1388,11 +1368,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1388,11 +1368,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i], arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], arg.c_grid_desc_m_n_container_[i],
arg.a_element_op_, GridwiseGemm::CalculateNumKBlockLoop(K));
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_container_[i]);
} }
} }
return ave_time; return ave_time;
...@@ -1446,8 +1423,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1446,8 +1423,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i], arg.c_grid_desc_m_n_container_[i]))
arg.block_2_ctile_map_container_[i]))
{ {
return false; return false;
} }
...@@ -1472,10 +1448,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1472,10 +1448,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads)
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{ {
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
...@@ -1490,11 +1463,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1490,11 +1463,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
1, 1};
1,
in_element_op,
wei_element_op,
out_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -1513,9 +1482,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1513,9 +1482,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation,
OutElementwiseOperation out_element_op) override OutElementwiseOperation) override
{ {
return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid), static_cast<const WeiDataType*>(p_wei_grid),
...@@ -1530,11 +1499,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1530,11 +1499,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
1, 1);
1,
in_element_op,
wei_element_op,
out_element_op);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace ck { namespace ck {
template <typename GridwiseGemm, typename FloatAB, typename FloatC, bool HasMainKBlockLoop> template <typename GridwiseGemm, bool HasMainKBlockLoop>
#ifdef USE_WAVES_PER_EU #ifdef USE_WAVES_PER_EU
__attribute__((amdgpu_waves_per_eu(1, 1))) __attribute__((amdgpu_waves_per_eu(1, 1)))
#endif #endif
...@@ -45,7 +45,7 @@ __global__ void ...@@ -45,7 +45,7 @@ __global__ void
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
c_grid_desc_m_n, c_grid_desc_m_n,
karg); karg.NumKBlockLoop);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...@@ -128,21 +128,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -128,21 +128,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} }
// Argument // Argument
struct Argument : public tensor_operation::device::BaseArgument struct Problem
{ {
__host__ Argument(const FloatAB* p_a_grid_, __host__ Problem(index_t M_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_, index_t N_,
index_t K_, index_t K_,
index_t StrideA_, index_t StrideA_,
index_t StrideB_, index_t StrideB_,
index_t StrideC_) index_t StrideC_)
: p_a_grid{p_a_grid_}, : M{M_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_}, N{N_},
K{K_}, K{K_},
StrideA{StrideA_}, StrideA{StrideA_},
...@@ -157,7 +151,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -157,7 +151,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__ void Print() const __host__ void Print() const
{ {
std::cout << "arg {" std::cout << "problem {"
<< "M:" << M << ", " << "M:" << M << ", "
<< "N:" << N << ", " << "N:" << N << ", "
<< "K:" << K << ", " << "K:" << K << ", "
...@@ -170,9 +164,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -170,9 +164,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<< "NumKBlockLoop: " << NumKBlockLoop << "}" << std::endl; << "NumKBlockLoop: " << NumKBlockLoop << "}" << std::endl;
} }
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
index_t M; index_t M;
index_t N; index_t N;
index_t K; index_t K;
...@@ -185,6 +176,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -185,6 +176,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
index_t NumKBlockLoop; index_t NumKBlockLoop;
}; };
// Argument
struct Argument : public Problem, public tensor_operation::device::BaseArgument
{
__host__ Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
{
}
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
};
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
...@@ -260,15 +275,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -260,15 +275,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
} }
template <typename AGridDesc_K0_M_K1, template <typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N>
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
typename Block2CTileMap>
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n, const CGridDesc_M_N& c_grid_desc_m_n)
const Block2CTileMap& block_2_ctile_map)
{ {
const auto M = a_grid_desc_k0_m_k1.GetLength(I1); const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
...@@ -290,17 +301,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -290,17 +301,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return false; return false;
} }
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true; return true;
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Argument& karg) __host__ static constexpr bool CheckValidity(const Problem& problem)
{ {
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value, static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time"); "wrong! K1 need to be known at compile-time");
...@@ -310,7 +316,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -310,7 +316,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
"Invalid tuning param!"); "Invalid tuning param!");
// check gridwise gemm pipeline // check gridwise gemm pipeline
const index_t K0 = karg.K / K1Value; const index_t K0 = problem.K / K1Value;
const auto num_k_loop = K0 / K0PerBlock; const auto num_k_loop = K0 / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) if(!GridwiseGemmPipe::IsSupported(num_k_loop))
...@@ -394,7 +400,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -394,7 +400,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n, const CGridDesc_M_N& c_grid_desc_m_n,
const Argument& karg) index_t NumKBlockLoop)
{ {
#if ENABLE_DUMP_CLOCK #if ENABLE_DUMP_CLOCK
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -417,7 +423,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -417,7 +423,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const BElementwiseOperation b_element_op{}; const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{}; const CElementwiseOperation c_element_op{};
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N}; const auto block_2_ctile_map =
Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)};
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
...@@ -546,7 +553,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -546,7 +553,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
// gridwise GEMM pipeline // gridwise GEMM pipeline
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(karg.NumKBlockLoop); const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(NumKBlockLoop);
#if ENABLE_DUMP_CLOCK #if ENABLE_DUMP_CLOCK
long loop_start = 0, loop_end = 0; long loop_start = 0, loop_end = 0;
...@@ -790,8 +797,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -790,8 +797,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
LoopSched, LoopSched,
PipelineVer>; PipelineVer>;
using typename Parent::Argument;
using typename Parent::GridwiseGemmPipe; using typename Parent::GridwiseGemmPipe;
using typename Parent::Problem;
using Parent::I1; using Parent::I1;
...@@ -899,7 +906,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -899,7 +906,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Argument& karg) __host__ static constexpr bool CheckValidity(const Problem& problem)
{ {
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value, static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time"); "wrong! K1 need to be known at compile-time");
...@@ -913,7 +920,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -913,7 +920,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{ {
if(!(karg.M % MPerBlock == 0)) if(!(problem.M % MPerBlock == 0))
{ {
return false; return false;
} }
...@@ -924,7 +931,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -924,7 +931,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{ {
if(!(karg.N % NPerBlock == 0)) if(!(problem.N % NPerBlock == 0))
{ {
return false; return false;
} }
...@@ -932,14 +939,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -932,14 +939,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{ {
if(karg.K % ABlockTransferSrcScalarPerVector != 0) if(problem.K % ABlockTransferSrcScalarPerVector != 0)
{ {
return false; return false;
} }
} }
else else
{ {
if(karg.M % ABlockTransferSrcScalarPerVector != 0) if(problem.M % ABlockTransferSrcScalarPerVector != 0)
{ {
return false; return false;
} }
...@@ -947,21 +954,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -947,21 +954,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
if(karg.N % BBlockTransferSrcScalarPerVector != 0) if(problem.N % BBlockTransferSrcScalarPerVector != 0)
{ {
return false; return false;
} }
} }
else else
{ {
if(karg.K % BBlockTransferSrcScalarPerVector != 0) if(problem.K % BBlockTransferSrcScalarPerVector != 0)
{ {
return false; return false;
} }
} }
// check gridwise gemm pipeline // check gridwise gemm pipeline
const index_t K0 = karg.K / K1; const index_t K0 = problem.K / K1;
const auto num_k_loop = K0 / K0PerBlock; const auto num_k_loop = K0 / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) if(!GridwiseGemmPipe::IsSupported(num_k_loop))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment