"sims/vscode:/vscode.git/clone" did not exist on "475b21b4f76924ef812f90911dece401bd20b502"
Commit 94e03cf5 authored by Po-Yen, Chen

Remove unnecessary kernel arguments

parent d23a7617
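In short, the element-wise operations, the block-to-C-tile map, and the M01/N01 factors no longer travel through the kernel argument list; the kernel receives only the grid pointers, the three grid descriptors, and a K-block loop count, and reconstructs the rest on the device from the C grid descriptor. A condensed before/after of the launch call, paraphrased from the hunks below rather than quoted verbatim:

// Before: element-wise ops and the tile map are passed as kernel arguments.
launch_and_time_kernel(stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                       p_a_grid, p_b_grid, p_c_grid,
                       a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1,
                       c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                       a_element_op, b_element_op, c_element_op, block_2_ctile_map);

// After: only the M_N descriptor and the loop count are passed; the element-wise
// ops and Block2CTileMap are default-constructed inside GridwiseGemm::Run.
launch_and_time_kernel(stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0,
                       p_a_grid, p_b_grid, p_c_grid,
                       a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n,
                       GridwiseGemm::CalculateNumKBlockLoop(K));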
@@ -20,6 +20,50 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace detail {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M_N c_grid_desc_m_n,
index_t NumKBlockLoop)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
NumKBlockLoop);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m_n;
ignore = NumKBlockLoop;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
}
} // namespace detail
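For context, the invoker shown further down picks the HasMainKBlockLoop template argument and the runtime loop count from the reduced K length of the A descriptor. A simplified sketch of that host-side selection, with grid-size computation and timing elided:

const auto K = a_grid_desc_k0_m_k1.GetLength(I0) * a_grid_desc_k0_m_k1.GetLength(I2);

if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
    const auto kernel = detail::kernel_convnd_bwd_data_nwc_kxc_nwk_xdl<
        GridwiseGemm, ADataType, CDataType,
        AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N,
        /*HasMainKBlockLoop=*/true>;
    // ... launch the kernel, passing GridwiseGemm::CalculateNumKBlockLoop(K) last
}
else
{
    // same launch, but instantiated with HasMainKBlockLoop = false
}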
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <ck::index_t NDimSpatial,
typename InDataType,
@@ -967,12 +1011,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
{0, 0, 0});
}
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
@@ -980,9 +1018,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
@@ -1014,6 +1049,15 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(std::declval<ABCGridDescs>()[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(std::declval<ABCGridDescs>()[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(std::declval<ABCGridDescs>()[I2])>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(std::declval<CGridDesc_M_N>()));
// Argument
struct Argument : public BaseArgument
{
@@ -1030,19 +1074,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
ck::index_t M01,
ck::index_t N01,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
ck::index_t M01)
: p_a_grid_{p_out_grid},
p_b_grid_{p_wei_grid},
p_c_grid_{p_in_grid},
M01_{M01},
N01_{N01},
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
@@ -1092,17 +1128,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
@@ -1150,18 +1175,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
}
@@ -1218,19 +1231,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
descs[I2]));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
}
@@ -1242,14 +1242,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_;
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
index_t M01_;
index_t N01_;
OutElementwiseOperation a_element_op_;
WeiElementwiseOperation b_element_op_;
InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
@@ -1315,39 +1308,34 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.block_2_ctile_map_container_[i]))
arg.c_grid_desc_m_n_container_[i]))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
arg.c_grid_desc_m_n_container_[i]);
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
arg.c_grid_desc_m_n_container_[i].GetLength(I0),
arg.c_grid_desc_m_n_container_[i].GetLength(I1));
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdlops_v2r3<
const auto kernel = detail::kernel_convnd_bwd_data_nwc_kxc_nwk_xdl<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation,
WeiElementwiseOperation,
InElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
true>;
ave_time += launch_and_time_kernel(
stream_config,
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
@@ -1355,32 +1343,24 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i],
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_container_[i]);
arg.c_grid_desc_m_n_container_[i],
GridwiseGemm::CalculateNumKBlockLoop(K));
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r3<
const auto kernel = detail::kernel_convnd_bwd_data_nwc_kxc_nwk_xdl<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation,
WeiElementwiseOperation,
InElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
false>;
ave_time += launch_and_time_kernel(
stream_config,
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
@@ -1388,11 +1368,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i],
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_container_[i]);
arg.c_grid_desc_m_n_container_[i],
GridwiseGemm::CalculateNumKBlockLoop(K));
}
}
return ave_time;
@@ -1446,8 +1423,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.block_2_ctile_map_container_[i]))
arg.c_grid_desc_m_n_container_[i]))
{
return false;
}
@@ -1472,10 +1448,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
std::vector<ck::index_t> input_right_pads)
{
return Argument{p_in_grid,
p_wei_grid,
@@ -1490,11 +1463,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
in_element_op,
wei_element_op,
out_element_op};
1};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -1513,9 +1482,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid),
@@ -1530,11 +1499,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
in_element_op,
wei_element_op,
out_element_op);
1);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
@@ -17,7 +17,7 @@
namespace ck {
template <typename GridwiseGemm, typename FloatAB, typename FloatC, bool HasMainKBlockLoop>
template <typename GridwiseGemm, bool HasMainKBlockLoop>
#ifdef USE_WAVES_PER_EU
__attribute__((amdgpu_waves_per_eu(1, 1)))
#endif
@@ -45,7 +45,7 @@ __global__ void
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
karg);
karg.NumKBlockLoop);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
@@ -128,21 +128,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
// Argument
struct Argument : public tensor_operation::device::BaseArgument
struct Problem
{
__host__ Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
@@ -157,7 +151,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__ void Print() const
{
std::cout << "arg {"
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
@@ -170,9 +164,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<< "NumKBlockLoop: " << NumKBlockLoop << "}" << std::endl;
}
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
index_t M;
index_t N;
index_t K;
@@ -185,6 +176,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
index_t NumKBlockLoop;
};
// Argument
struct Argument : public Problem, public tensor_operation::device::BaseArgument
{
__host__ Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
{
}
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
};
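Separating Problem from Argument keeps the size/stride description free of device pointers, so validity can be checked on the host before any buffers are bound. A minimal usage sketch under that assumption (variable names are illustrative, not from the source):

// Hypothetical host-side flow (template arguments of the gridwise gemm omitted):
using Gemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3</* ... */>;

typename Gemm::Problem problem{M, N, K, StrideA, StrideB, StrideC};
if(Gemm::CheckValidity(problem))
{
    typename Gemm::Argument arg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
    // ... launch with arg; the kernel reads the sizes through the Problem base subobject
}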
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
@@ -260,15 +275,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
}
template <typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
typename Block2CTileMap>
template <typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
@@ -290,17 +301,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Argument& karg)
__host__ static constexpr bool CheckValidity(const Problem& problem)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
@@ -310,7 +316,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
"Invalid tuning param!");
// check gridwise gemm pipeline
const index_t K0 = karg.K / K1Value;
const index_t K0 = problem.K / K1Value;
const auto num_k_loop = K0 / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
@@ -394,7 +400,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Argument& karg)
index_t NumKBlockLoop)
{
#if ENABLE_DUMP_CLOCK
__builtin_amdgcn_sched_barrier(0);
@@ -417,7 +423,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N};
const auto block_2_ctile_map =
Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)};
// divide block work by [M, N]
const auto block_work_idx =
@@ -546,7 +553,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
// gridwise GEMM pipeline
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(karg.NumKBlockLoop);
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(NumKBlockLoop);
#if ENABLE_DUMP_CLOCK
long loop_start = 0, loop_end = 0;
@@ -790,8 +797,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
LoopSched,
PipelineVer>;
using typename Parent::Argument;
using typename Parent::GridwiseGemmPipe;
using typename Parent::Problem;
using Parent::I1;
@@ -899,7 +906,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Argument& karg)
__host__ static constexpr bool CheckValidity(const Problem& problem)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
@@ -913,7 +920,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.M % MPerBlock == 0))
if(!(problem.M % MPerBlock == 0))
{
return false;
}
@@ -924,7 +931,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.N % NPerBlock == 0))
if(!(problem.N % NPerBlock == 0))
{
return false;
}
@@ -932,14 +939,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
if(problem.K % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
if(problem.M % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
@@ -947,21 +954,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
if(problem.N % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
if(problem.K % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
// check gridwise gemm pipeline
const index_t K0 = karg.K / K1;
const index_t K0 = problem.K / K1;
const auto num_k_loop = K0 / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
......