"tests/pipelines/vscode:/vscode.git/clone" did not exist on "0c6d1bc985d2373d742d323283994f3dc2e50965"
Commit 5decb4a7 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Merge branch 'feature/integrage-karg-simplification-pr' into feature/simplify-karg-for-device-gemm-xdl
parents 7c2b82ca 64b9b6a0
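The theme of this merge shows up in every hunk below: GridwiseGemm's monolithic kernel argument ("karg") is split into a plain Problem struct (sizes, strides, and quantities derived from them) and an Argument that inherits from both tensor_operation::device::BaseArgument and Problem and adds the grid pointers. Device operations then either reuse GridwiseGemm::Argument directly or derive from Problem themselves, and two kernel entry points appear: one that receives a full Argument by value, and one that receives the pointers separately alongside a Problem. The host-only sketch below shows the shape of that split; MiniProblem, MiniArgument, mini_kernel_v1 and mini_kernel_v2 are illustrative stand-ins, not names from the library.

#include <cstdio>

using index_t = int;

// "Problem": sizes and strides only -- cheap to copy by value into a kernel.
struct MiniProblem
{
    index_t M, N, K;
    index_t StrideA, StrideB, StrideC;
};

// "Argument": the Problem plus the grid pointers owned by the device op.
struct MiniArgument : public MiniProblem
{
    const float* p_a_grid;
    const float* p_b_grid;
    float* p_c_grid;
};

// v1-style entry point: everything travels inside one Argument.
void mini_kernel_v1(MiniArgument karg)
{
    std::printf("v1: M=%d N=%d K=%d\n", karg.M, karg.N, karg.K);
}

// v2-style entry point: pointers are passed separately, only the Problem is copied.
void mini_kernel_v2(const float* p_a, const float* p_b, float* p_c, MiniProblem problem)
{
    (void)p_a;
    (void)p_b;
    (void)p_c;
    std::printf("v2: M=%d N=%d K=%d\n", problem.M, problem.N, problem.K);
}

int main()
{
    float a[4] = {}, b[4] = {}, c[4] = {};
    MiniArgument arg{{2, 2, 2, 2, 2, 2}, a, b, c};
    mini_kernel_v1(arg);
    // Passing arg where a MiniProblem is expected slices off the pointer members.
    mini_kernel_v2(arg.p_a_grid, arg.p_b_grid, arg.p_c_grid, arg);
    return 0;
}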
@@ -168,9 +168,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
// Argument
struct Argument : public GridwiseGemm::Argument
struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem
{
using Parent = typename GridwiseGemm::Argument;
using Problem = typename GridwiseGemm::Problem;
Argument(const ADataType* p_a_grid_real_,
const ADataType* p_a_grid_imag_,
@@ -185,7 +185,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Parent(M_, N_, K_, StrideA_, StrideB_, StrideC_),
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
p_a_grid_real{p_a_grid_real_},
p_a_grid_imag{p_a_grid_imag_},
p_b_grid_real{p_b_grid_real_},
@@ -225,22 +225,22 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
karg.Print();
arg.Print();
}
if(!GridwiseGemm::CheckValidity(karg))
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
const auto K = GridwiseGemm::CalculateAK0(karg.K) * AK1;
const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
float ave_time = 0;
@@ -284,27 +284,28 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>;
const auto kernel =
kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, ADataType, CDataType, true>;
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real,
karg.p_b_grid_real,
karg.p_aux_grid,
karg);
arg.p_a_grid_real,
arg.p_b_grid_real,
arg.p_aux_grid,
arg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag,
karg.p_b_grid_imag,
karg.p_aux_2_grid,
karg);
arg.p_a_grid_imag,
arg.p_b_grid_imag,
arg.p_aux_2_grid,
arg);
// c_real = aux - aux_2
ave_time += launch_and_time_kernel(
@@ -313,11 +314,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_real),
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_real),
Subtract{});
ave_time += launch_and_time_kernel(stream_config,
@@ -325,20 +326,20 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real,
karg.p_b_grid_imag,
karg.p_aux_grid,
karg);
arg.p_a_grid_real,
arg.p_b_grid_imag,
arg.p_aux_grid,
arg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag,
karg.p_b_grid_real,
karg.p_aux_2_grid,
karg);
arg.p_a_grid_imag,
arg.p_b_grid_real,
arg.p_aux_2_grid,
arg);
// c_imag = aux + aux_2
ave_time += launch_and_time_kernel(
@@ -347,36 +348,37 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_imag),
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_imag),
Add{});
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, false>;
const auto kernel =
kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, ADataType, CDataType, false>;
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real,
karg.p_b_grid_real,
karg.p_aux_grid,
karg);
arg.p_a_grid_real,
arg.p_b_grid_real,
arg.p_aux_grid,
arg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag,
karg.p_b_grid_imag,
karg.p_aux_2_grid,
karg);
arg.p_a_grid_imag,
arg.p_b_grid_imag,
arg.p_aux_2_grid,
arg);
// c_real = aux - aux_2
ave_time += launch_and_time_kernel(
@@ -385,11 +387,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_real),
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_real),
Subtract{});
ave_time += launch_and_time_kernel(stream_config,
@@ -397,20 +399,20 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real,
karg.p_b_grid_imag,
karg.p_aux_grid,
karg);
arg.p_a_grid_real,
arg.p_b_grid_imag,
arg.p_aux_grid,
arg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag,
karg.p_b_grid_real,
karg.p_aux_2_grid,
karg);
arg.p_a_grid_imag,
arg.p_b_grid_real,
arg.p_aux_2_grid,
arg);
// c_imag = aux + aux_2
ave_time += launch_and_time_kernel(
@@ -419,11 +421,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_imag),
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_imag),
Add{});
}
@@ -444,9 +446,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
return true;
}
static bool IsSupportedArgument(const Argument& karg)
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(karg);
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
......
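For readers skimming the launches above: DeviceCGemm_4Gemm_Xdl_CShuffle realizes one complex GEMM as four real GEMMs plus two elementwise passes, because (A_r + iA_i)(B_r + iB_i) = (A_r B_r - A_i B_i) + i(A_r B_i + A_i B_r). The aux and aux_2 buffers hold the partial products, and the Subtract/Add kernels produce c_real and c_imag, matching the "c_real = aux - aux_2" and "c_imag = aux + aux_2" comments. The scalar sketch below is only a worked example of that identity; it is not library code.

#include <cstdio>

int main()
{
    // 1x1 "matrices" keep the arithmetic visible; the device op does the same
    // elementwise over the M x N output.
    const float a_real = 1.0f, a_imag = 2.0f; // A = 1 + 2i
    const float b_real = 3.0f, b_imag = 4.0f; // B = 3 + 4i

    float aux   = a_real * b_real;    // GEMM 1: A_real * B_real
    float aux_2 = a_imag * b_imag;    // GEMM 2: A_imag * B_imag
    const float c_real = aux - aux_2; // elementwise Subtract

    aux   = a_real * b_imag;          // GEMM 3 reuses the aux buffer
    aux_2 = a_imag * b_real;          // GEMM 4 reuses aux_2
    const float c_imag = aux + aux_2; // elementwise Add

    std::printf("C = %g + %gi (expected -5 + 10i)\n", c_real, c_imag);
    return 0;
}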
@@ -130,80 +130,43 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
LoopSched,
PipelineVer>;
struct Argument : public GridwiseGemm::Argument
{
using Parent = typename GridwiseGemm::Argument;
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Parent(M_, N_, K_, StrideA_, StrideB_, StrideC_),
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
{
}
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
};
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
karg.Print();
arg.Print();
}
if(!GridwiseGemm::CheckValidity(karg))
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
const auto K = GridwiseGemm::CalculateAK0(karg.K) * AK1;
const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>;
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg);
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, false>;
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg);
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
return ave_time;
@@ -223,9 +186,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return true;
}
static bool IsSupportedArgument(const Argument& karg)
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(karg);
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
......
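The DeviceGemm_Xdl_CShuffle hunks above show the call-site consequence of the simplification: launch_and_time_kernel forwards its trailing arguments straight to the kernel, so when the kernel signature shrinks to a single Argument, the launch passes exactly that one object instead of three pointers plus karg. The toy launcher below mimics that forwarding to make the coupling explicit; it is a sketch assuming a simple variadic forwarder, not the library's launch_and_time_kernel.

#include <cstdio>
#include <utility>

// Toy stand-in for a launch helper: it only forwards arguments to the kernel.
// The real helper also sets grid/block dimensions and times the launch.
template <typename Kernel, typename... Args>
float launch_and_time(Kernel&& kernel, Args&&... args)
{
    std::forward<Kernel>(kernel)(std::forward<Args>(args)...);
    return 0.0f; // elapsed milliseconds in the real helper
}

struct Arg
{
    int M, N, K;
    const float* p_a;
    const float* p_b;
    float* p_c;
};

// New-style kernel: the whole Argument travels as one parameter.
void kernel_v1(Arg arg) { std::printf("kernel sees M=%d N=%d K=%d\n", arg.M, arg.N, arg.K); }

int main()
{
    float a[1] = {}, b[1] = {}, c[1] = {};
    Arg arg{8, 16, 32, a, b, c};
    // Old call sites passed (p_a, p_b, p_c, karg); the new kernel takes just arg,
    // so the launch passes just arg.
    launch_and_time(kernel_v1, arg);
    return 0;
}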
@@ -25,32 +25,49 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdl_cshuffle_v1_simplified(
const typename GridwiseGemm::FloatAB* __restrict__ p_a_grid,
const typename GridwiseGemm::FloatAB* __restrict__ p_b_grid,
typename GridwiseGemm::FloatC* __restrict__ p_c_grid,
typename GridwiseGemm::Argument karg)
kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, karg);
GridwiseGemm::template Run<HasMainKBlockLoop>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm, typename FloatAB, typename FloatC, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdl_cshuffle_v2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
typename GridwiseGemm::Problem problem)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = karg;
ignore = problem;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename ALayout,
typename BLayout,
typename CLayout,
typename FloatAB_,
typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC_,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
@@ -106,9 +123,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
using FloatAB = FloatAB_;
using FloatC = FloatC_;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N)
@@ -392,15 +406,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
// Argument
struct Argument : public tensor_operation::device::BaseArgument
struct Problem
{
__host__ Argument(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: M{M_},
N{N_},
K{K_},
@@ -419,7 +432,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__ void Print() const
{
std::cout << "arg {"
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
@@ -450,6 +463,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
index_t NBlock;
};
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
{
}
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
};
// FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
@@ -513,7 +550,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Argument& karg)
__host__ static constexpr bool CheckValidity(const Problem& problem)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -524,7 +561,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.M % MPerBlock == 0))
if(!(problem.M % MPerBlock == 0))
{
return false;
}
@@ -535,7 +572,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.N % NPerBlock == 0))
if(!(problem.N % NPerBlock == 0))
{
return false;
}
@@ -546,15 +583,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
{
if(!(CalculateKPadded(karg.K) % AK1Value == 0) ||
!(CalculateKPadded(karg.K) % BK1Value == 0))
if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
!(CalculateKPadded(problem.K) % BK1Value == 0))
{
return false;
}
}
else
{
if(!(karg.K % AK1Value == 0) || !(karg.K % BK1Value == 0))
if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
{
return false;
}
@@ -562,14 +599,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
if(problem.K % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
if(problem.M % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
@@ -577,14 +614,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
if(problem.N % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
if(problem.K % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
@@ -592,21 +629,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
// check gridwise gemm pipeline
const auto num_k_loop = (CalculateAK0(karg.K) * AK1Value) / KPerBlock;
const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
@@ -646,7 +683,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const Argument& karg)
const Problem& problem)
{
#if ENABLE_DUMP_CLOCK
__builtin_amdgcn_sched_barrier(0);
@@ -656,15 +693,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
#endif
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
karg.M, karg.MPadded, karg.K, karg.KPadded, karg.StrideA, karg.AK0);
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
karg.K, karg.KPadded, karg.N, karg.NPadded, karg.StrideB, karg.BK0);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, karg.MBlock, karg.NBlock);
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -678,7 +715,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N};
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......
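One more detail worth calling out from the last file: CheckValidity now takes a Problem, yet IsSupportedArgument(arg) and the Invokers keep passing an Argument. That works because the new Argument publicly derives from Problem, so a const Problem& simply binds to the Argument's base subobject. A minimal sketch of that relationship follows; the toy types and the K % 8 rule are illustrative assumptions, not the library's actual validity checks.

#include <cstdio>

struct Problem
{
    int M = 0, N = 0, K = 0;
};

// Argument extends the Problem with pointers, as in the refactored GridwiseGemm.
struct Argument : public Problem
{
    const float* p_a = nullptr;
    const float* p_b = nullptr;
    float* p_c = nullptr;
};

// Validity only depends on shape information, so the check accepts the base type.
bool CheckValidity(const Problem& problem) { return problem.K % 8 == 0; } // illustrative rule

// The device-op layer still holds an Argument and forwards it unchanged;
// the reference parameter binds to the Argument's Problem subobject.
bool IsSupportedArgument(const Argument& arg) { return CheckValidity(arg); }

int main()
{
    Argument arg;
    arg.M = 128;
    arg.N = 128;
    arg.K = 64;
    std::printf("supported: %s\n", IsSupportedArgument(arg) ? "yes" : "no"); // prints "yes"
    return 0;
}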