Commit bbe5c0c7 authored by Jing Zhang's avatar Jing Zhang
Browse files

2 gemm test

parent 6cbb0a13
...@@ -176,6 +176,7 @@ struct gemm_desc ...@@ -176,6 +176,7 @@ struct gemm_desc
ck::index_t M, N, K; ck::index_t M, N, K;
ck::index_t StrideA, StrideB, StrideC; ck::index_t StrideA, StrideB, StrideC;
ck::index_t OffsetA, OffsetB, OffsetC; ck::index_t OffsetA, OffsetB, OffsetC;
ck::index_t BlockStart, BlockSize;
}; };
} // namespace ck } // namespace ck
......
...@@ -18,11 +18,13 @@ template <typename GridwiseGemm, ...@@ -18,11 +18,13 @@ template <typename GridwiseGemm,
typename AGridDesc_K0_M_K1, typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename GemmDesc,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainK0BlockLoop> bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -34,6 +36,8 @@ __global__ void ...@@ -34,6 +36,8 @@ __global__ void
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const GemmDesc gemm_shapes,
const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
...@@ -43,24 +47,67 @@ __global__ void ...@@ -43,24 +47,67 @@ __global__ void
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const index_t group_id = 0; index_t group_id = 0;
index_t block_id_grp = 0;
index_t a_offset_grp = 0;
index_t b_offset_grp = 0;
index_t c_offset_grp = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < group_count)
{
if(block_id >= gemm_shapes[i].BlockStart &&
block_id < (gemm_shapes[i].BlockStart + gemm_shapes[i].BlockSize))
{
group_id = i;
block_id_grp = block_id - gemm_shapes[i].BlockStart;
a_offset_grp = gemm_shapes[i].OffsetA;
b_offset_grp = gemm_shapes[i].OffsetB;
c_offset_grp = gemm_shapes[i].OffsetC;
// if(get_thread_local_1d_id() == 0)
// printf("%d %d %d %d %d %d\n",
// block_id,
// group_id,
// block_id_grp,
// a_offset_grp,
// b_offset_grp,
// c_offset_grp);
}
}
});
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
if(group_id == 0) if(group_id == 0)
GridwiseGemm::template Run<HasMainK0BlockLoop>( GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_a_grid, p_b_grid + b_offset_grp,
p_b_grid, p_c_grid + c_offset_grp,
p_c_grid, p_shared,
p_shared, a_grid_desc_k0_m_k1[I0],
a_grid_desc_k0_m_k1[Number<0>{}], b_grid_desc_k0_n_k1[I0],
b_grid_desc_k0_n_k1[Number<0>{}], c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[I0],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[Number<0>{}], a_element_op,
a_element_op, b_element_op,
b_element_op, c_element_op,
c_element_op, block_2_ctile_map[I0],
block_2_ctile_map[Number<0>{}], block_id_grp,
block_id); group_id);
else if(group_id == 1)
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1[I1],
b_grid_desc_k0_n_k1[I1],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[I1],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map[I1],
block_id_grp,
group_id);
} }
template <index_t BlockSize, template <index_t BlockSize,
...@@ -360,7 +407,8 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -360,7 +407,8 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const index_t block_id) const index_t block_id,
const index_t group_id)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
...@@ -382,6 +430,14 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -382,6 +430,14 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_block_data_idx_on_grid = const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// if(get_thread_local_1d_id() == 0)
//{
// printf("m: %d n: %d k: %d\n", a_grid_desc_k0_m_k1.GetLength(I1),
// b_grid_desc_k0_n_k1.GetLength(I1), a_grid_desc_k0_m_k1.GetLength(I0));
// printf("block_work_idx: %d %d %d %d\n", group_id, block_id, block_work_idx[I0],
// block_work_idx[I1]);
//}
// lds max alignment // lds max alignment
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
......
...@@ -53,8 +53,8 @@ template <typename ADataType, ...@@ -53,8 +53,8 @@ template <typename ADataType,
bool BBlockLdsAddExtraN, bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector, ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1, ck::index_t NumPrefetch = 1,
ck::index_t GroupCount = 1> ck::index_t MaxGroupCount = 5>
struct DeviceGroupedGemmXdl struct DeviceGroupedGemmXdl
: public DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
...@@ -238,55 +238,62 @@ struct DeviceGroupedGemmXdl ...@@ -238,55 +238,62 @@ struct DeviceGroupedGemmXdl
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
gemm_shapes_{gemm_shapes},
M01_{M01}, M01_{M01},
N01_{N01}, N01_{N01},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
const index_t i = 0;
const index_t M = gemm_shapes[Number<0>{}].M;
const index_t N = gemm_shapes[Number<0>{}].N;
const index_t K = gemm_shapes[Number<0>{}].K;
const index_t StrideA = gemm_shapes[Number<0>{}].StrideA;
const index_t StrideB = gemm_shapes[Number<0>{}].StrideB;
const index_t StrideC = gemm_shapes[Number<0>{}].StrideC;
a_grid_desc_k0_m_k1_(Number<0>{}) =
DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_(Number<0>{}) =
DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_(Number<0>{}) =
DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_[Number<0>{}],
b_grid_desc_k0_n_k1_[Number<0>{}],
c_grid_desc_m_n_[Number<0>{}],
M01_,
N01_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_(Number<0>{}) =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_m_n_[Number<0>{}]);
block_2_ctile_map_(Number<0>{}) = GridwiseGemm::MakeDefaultBlock2CTileMap( static_for<0, MaxGroupCount, 1>{}([&](auto i) {
c_grid_desc_m_n_[Number<0>{}], M01, N01); if(i < gemm_shapes_.size())
} {
const index_t M = gemm_shapes_[i].M;
const index_t N = gemm_shapes_[i].N;
const index_t K = gemm_shapes_[i].K;
const index_t StrideA = gemm_shapes_[i].StrideA;
const index_t StrideB = gemm_shapes_[i].StrideB;
const index_t StrideC = gemm_shapes_[i].StrideC;
a_grid_desc_k0_m_k1_(i) =
DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_(i) =
DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_(i) =
DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_[i],
b_grid_desc_k0_n_k1_[i],
c_grid_desc_m_n_[i],
M01_,
N01_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_(i) =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_m_n_[i]);
block_2_ctile_map_(i) =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_[i], M01, N01);
}
}
});
} }
// private: // private:
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
StaticallyIndexedArray<AGridDesc_K0_M_K1, GroupCount> a_grid_desc_k0_m_k1_; StaticallyIndexedArray<AGridDesc_K0_M_K1, MaxGroupCount> a_grid_desc_k0_m_k1_;
StaticallyIndexedArray<BGridDesc_K0_N_K1, GroupCount> b_grid_desc_k0_n_k1_; StaticallyIndexedArray<BGridDesc_K0_N_K1, MaxGroupCount> b_grid_desc_k0_n_k1_;
StaticallyIndexedArray<CGridDesc_M_N, GroupCount> c_grid_desc_m_n_; StaticallyIndexedArray<CGridDesc_M_N, MaxGroupCount> c_grid_desc_m_n_;
StaticallyIndexedArray<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, GroupCount> StaticallyIndexedArray<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
MaxGroupCount>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap, GroupCount> StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap, MaxGroupCount>
block_2_ctile_map_; block_2_ctile_map_;
std::vector<gemm_desc> gemm_shapes_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -301,33 +308,49 @@ struct DeviceGroupedGemmXdl ...@@ -301,33 +308,49 @@ struct DeviceGroupedGemmXdl
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
std::cout << "Gemm " << i << ":" << std::endl; StaticallyIndexedArray<gemm_desc, MaxGroupCount> gemm_shapes;
std::cout << "arg.a_grid_desc_k0_m_k1_{"
<< arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I0) << ", " index_t grid_size = 0;
<< arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I2) << "}" << std::endl; static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < arg.gemm_shapes_.size())
std::cout << "arg.b_grid_desc_k0_n_k1_{" {
<< arg.b_grid_desc_k0_n_k1_[Number<0>{}].GetLength(I0) << ", " std::cout << "arg.a_grid_desc_k0_m_k1_{"
<< arg.b_grid_desc_k0_n_k1_[Number<0>{}].GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_[Number<0>{}].GetLength(I2) << "}" << std::endl; << arg.a_grid_desc_k0_m_k1_[i].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_[i].GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_[Number<0>{}].GetLength(I0)
<< ", " << arg.c_grid_desc_m_n_[Number<0>{}].GetLength(I1) << "}" std::cout << "arg.b_grid_desc_k0_n_k1_{"
<< std::endl; << arg.b_grid_desc_k0_n_k1_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_[i].GetLength(I1) << ", "
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_[Number<0>{}], << arg.b_grid_desc_k0_n_k1_[i].GetLength(I2) << "}" << std::endl;
arg.b_grid_desc_k0_n_k1_[Number<0>{}],
arg.c_grid_desc_m_n_[Number<0>{}], std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_[i].GetLength(I0)
arg.M01_, << ", " << arg.c_grid_desc_m_n_[i].GetLength(I1) << "}" << std::endl;
arg.N01_))
{ if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_[i],
throw std::runtime_error( arg.b_grid_desc_k0_n_k1_[i],
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); arg.c_grid_desc_m_n_[i],
} arg.M01_,
arg.N01_))
const index_t grid_size = {
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_[Number<0>{}]); throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const index_t grid_size_grp =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_[i]);
gemm_shapes(i) = arg.gemm_shapes_[i];
gemm_shapes(i).BlockStart = grid_size;
gemm_shapes(i).BlockSize = grid_size_grp;
grid_size += grid_size_grp;
std::cout << "group_id " << i << " BlockStart " << gemm_shapes(i).BlockStart
<< " BlockSize " << gemm_shapes(i).BlockSize << std::endl;
}
});
const auto K0 = arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I0); const auto K0 = arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I0);
...@@ -344,20 +367,22 @@ struct DeviceGroupedGemmXdl ...@@ -344,20 +367,22 @@ struct DeviceGroupedGemmXdl
CDataType, CDataType,
remove_reference_t< remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1, StaticallyIndexedArray<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1,
GroupCount>>, MaxGroupCount>>,
remove_reference_t< remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1, StaticallyIndexedArray<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1,
GroupCount>>, MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray< remove_reference_t<StaticallyIndexedArray<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
GroupCount>>, MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<gemm_desc, MaxGroupCount>>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t< remove_reference_t<
StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap, StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap,
GroupCount>>, MaxGroupCount>>,
true>; true,
MaxGroupCount>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -370,6 +395,8 @@ struct DeviceGroupedGemmXdl ...@@ -370,6 +395,8 @@ struct DeviceGroupedGemmXdl
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
gemm_shapes,
arg.gemm_shapes_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -383,20 +410,22 @@ struct DeviceGroupedGemmXdl ...@@ -383,20 +410,22 @@ struct DeviceGroupedGemmXdl
CDataType, CDataType,
remove_reference_t< remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1, StaticallyIndexedArray<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1,
GroupCount>>, MaxGroupCount>>,
remove_reference_t< remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1, StaticallyIndexedArray<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1,
GroupCount>>, MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray< remove_reference_t<StaticallyIndexedArray<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
GroupCount>>, MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<gemm_desc, MaxGroupCount>>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t< remove_reference_t<
StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap, StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap,
GroupCount>>, MaxGroupCount>>,
false>; false,
MaxGroupCount>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -409,6 +438,8 @@ struct DeviceGroupedGemmXdl ...@@ -409,6 +438,8 @@ struct DeviceGroupedGemmXdl
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
gemm_shapes,
arg.gemm_shapes_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -434,7 +465,6 @@ struct DeviceGroupedGemmXdl ...@@ -434,7 +465,6 @@ struct DeviceGroupedGemmXdl
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
const index_t i = 0;
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_[Number<0>{}], return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_[Number<0>{}],
arg.b_grid_desc_k0_n_k1_[Number<0>{}], arg.b_grid_desc_k0_n_k1_[Number<0>{}],
arg.c_grid_desc_m_n_[Number<0>{}], arg.c_grid_desc_m_n_[Number<0>{}],
......
...@@ -76,7 +76,7 @@ int main(int argc, char* argv[]) ...@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
int group_count = 1; int group_count = 2;
// GEMM shape // GEMM shape
std::vector<ck::gemm_desc> gemm_shapes; std::vector<ck::gemm_desc> gemm_shapes;
...@@ -85,11 +85,11 @@ int main(int argc, char* argv[]) ...@@ -85,11 +85,11 @@ int main(int argc, char* argv[])
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
int M = 256; int M = 256 * (i + 1);
int N = 512; int N = 512 * (i + 1);
int K = 1024; int K = 1024 * (i + 1);
gemm_shapes.push_back({M, N, K, K, K, N, A_size, B_size, C_size}); gemm_shapes.push_back({M, N, K, K, K, N, A_size, B_size, C_size, 0, 0});
A_size += M * K; A_size += M * K;
B_size += N * K; B_size += N * K;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment