Commit bbe5c0c7 authored by Jing Zhang
Browse files

2 gemm test

parent 6cbb0a13
......@@ -176,6 +176,7 @@ struct gemm_desc
ck::index_t M, N, K;
ck::index_t StrideA, StrideB, StrideC;
ck::index_t OffsetA, OffsetB, OffsetC;
ck::index_t BlockStart, BlockSize;
};
} // namespace ck
......
......@@ -18,11 +18,13 @@ template <typename GridwiseGemm,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainK0BlockLoop>
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -34,6 +36,8 @@ __global__ void
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const GemmDesc gemm_shapes,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
......@@ -43,24 +47,67 @@ __global__ void
const index_t block_id = get_block_1d_id();
const index_t group_id = 0;
index_t group_id = 0;
index_t block_id_grp = 0;
index_t a_offset_grp = 0;
index_t b_offset_grp = 0;
index_t c_offset_grp = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < group_count)
{
if(block_id >= gemm_shapes[i].BlockStart &&
block_id < (gemm_shapes[i].BlockStart + gemm_shapes[i].BlockSize))
{
group_id = i;
block_id_grp = block_id - gemm_shapes[i].BlockStart;
a_offset_grp = gemm_shapes[i].OffsetA;
b_offset_grp = gemm_shapes[i].OffsetB;
c_offset_grp = gemm_shapes[i].OffsetC;
// if(get_thread_local_1d_id() == 0)
// printf("%d %d %d %d %d %d\n",
// block_id,
// group_id,
// block_id_grp,
// a_offset_grp,
// b_offset_grp,
// c_offset_grp);
}
}
});
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
if(group_id == 0)
GridwiseGemm::template Run<HasMainK0BlockLoop>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_k0_m_k1[Number<0>{}],
b_grid_desc_k0_n_k1[Number<0>{}],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[Number<0>{}],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map[Number<0>{}],
block_id);
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1[I0],
b_grid_desc_k0_n_k1[I0],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[I0],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map[I0],
block_id_grp,
group_id);
else if(group_id == 1)
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1[I1],
b_grid_desc_k0_n_k1[I1],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[I1],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map[I1],
block_id_grp,
group_id);
}
template <index_t BlockSize,
......@@ -360,7 +407,8 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map,
const index_t block_id)
const index_t block_id,
const index_t group_id)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
......@@ -382,6 +430,14 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// if(get_thread_local_1d_id() == 0)
//{
// printf("m: %d n: %d k: %d\n", a_grid_desc_k0_m_k1.GetLength(I1),
// b_grid_desc_k0_n_k1.GetLength(I1), a_grid_desc_k0_m_k1.GetLength(I0));
// printf("block_work_idx: %d %d %d %d\n", group_id, block_id, block_work_idx[I0],
// block_work_idx[I1]);
//}
// lds max alignment
constexpr auto max_lds_align = K1;
......
......@@ -53,8 +53,8 @@ template <typename ADataType,
bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1,
ck::index_t GroupCount = 1>
ck::index_t NumPrefetch = 1,
ck::index_t MaxGroupCount = 5>
struct DeviceGroupedGemmXdl
: public DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
......@@ -238,55 +238,62 @@ struct DeviceGroupedGemmXdl
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
gemm_shapes_{gemm_shapes},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
const index_t i = 0;
const index_t M = gemm_shapes[Number<0>{}].M;
const index_t N = gemm_shapes[Number<0>{}].N;
const index_t K = gemm_shapes[Number<0>{}].K;
const index_t StrideA = gemm_shapes[Number<0>{}].StrideA;
const index_t StrideB = gemm_shapes[Number<0>{}].StrideB;
const index_t StrideC = gemm_shapes[Number<0>{}].StrideC;
a_grid_desc_k0_m_k1_(Number<0>{}) =
DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_(Number<0>{}) =
DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_(Number<0>{}) =
DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_[Number<0>{}],
b_grid_desc_k0_n_k1_[Number<0>{}],
c_grid_desc_m_n_[Number<0>{}],
M01_,
N01_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_(Number<0>{}) =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_m_n_[Number<0>{}]);
block_2_ctile_map_(Number<0>{}) = GridwiseGemm::MakeDefaultBlock2CTileMap(
c_grid_desc_m_n_[Number<0>{}], M01, N01);
}
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < gemm_shapes_.size())
{
const index_t M = gemm_shapes_[i].M;
const index_t N = gemm_shapes_[i].N;
const index_t K = gemm_shapes_[i].K;
const index_t StrideA = gemm_shapes_[i].StrideA;
const index_t StrideB = gemm_shapes_[i].StrideB;
const index_t StrideC = gemm_shapes_[i].StrideC;
a_grid_desc_k0_m_k1_(i) =
DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_(i) =
DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_(i) =
DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_[i],
b_grid_desc_k0_n_k1_[i],
c_grid_desc_m_n_[i],
M01_,
N01_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_(i) =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_m_n_[i]);
block_2_ctile_map_(i) =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_[i], M01, N01);
}
}
});
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
StaticallyIndexedArray<AGridDesc_K0_M_K1, GroupCount> a_grid_desc_k0_m_k1_;
StaticallyIndexedArray<BGridDesc_K0_N_K1, GroupCount> b_grid_desc_k0_n_k1_;
StaticallyIndexedArray<CGridDesc_M_N, GroupCount> c_grid_desc_m_n_;
StaticallyIndexedArray<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, GroupCount>
StaticallyIndexedArray<AGridDesc_K0_M_K1, MaxGroupCount> a_grid_desc_k0_m_k1_;
StaticallyIndexedArray<BGridDesc_K0_N_K1, MaxGroupCount> b_grid_desc_k0_n_k1_;
StaticallyIndexedArray<CGridDesc_M_N, MaxGroupCount> c_grid_desc_m_n_;
StaticallyIndexedArray<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
MaxGroupCount>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap, GroupCount>
StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap, MaxGroupCount>
block_2_ctile_map_;
std::vector<gemm_desc> gemm_shapes_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
......@@ -301,33 +308,49 @@ struct DeviceGroupedGemmXdl
float Run(const Argument& arg, int nrepeat = 1)
{
std::cout << "Gemm " << i << ":" << std::endl;
std::cout << "arg.a_grid_desc_k0_m_k1_{"
<< arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{"
<< arg.b_grid_desc_k0_n_k1_[Number<0>{}].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_[Number<0>{}].GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_[Number<0>{}].GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_[Number<0>{}].GetLength(I0)
<< ", " << arg.c_grid_desc_m_n_[Number<0>{}].GetLength(I1) << "}"
<< std::endl;
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_[Number<0>{}],
arg.b_grid_desc_k0_n_k1_[Number<0>{}],
arg.c_grid_desc_m_n_[Number<0>{}],
arg.M01_,
arg.N01_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_[Number<0>{}]);
StaticallyIndexedArray<gemm_desc, MaxGroupCount> gemm_shapes;
index_t grid_size = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < arg.gemm_shapes_.size())
{
std::cout << "arg.a_grid_desc_k0_m_k1_{"
<< arg.a_grid_desc_k0_m_k1_[i].GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_[i].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_[i].GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{"
<< arg.b_grid_desc_k0_n_k1_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_[i].GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_[i].GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_[i].GetLength(I0)
<< ", " << arg.c_grid_desc_m_n_[i].GetLength(I1) << "}" << std::endl;
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_[i],
arg.b_grid_desc_k0_n_k1_[i],
arg.c_grid_desc_m_n_[i],
arg.M01_,
arg.N01_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const index_t grid_size_grp =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_[i]);
gemm_shapes(i) = arg.gemm_shapes_[i];
gemm_shapes(i).BlockStart = grid_size;
gemm_shapes(i).BlockSize = grid_size_grp;
grid_size += grid_size_grp;
std::cout << "group_id " << i << " BlockStart " << gemm_shapes(i).BlockStart
<< " BlockSize " << gemm_shapes(i).BlockSize << std::endl;
}
});
const auto K0 = arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I0);
......@@ -344,20 +367,22 @@ struct DeviceGroupedGemmXdl
CDataType,
remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1,
GroupCount>>,
MaxGroupCount>>,
remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1,
GroupCount>>,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
GroupCount>>,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<gemm_desc, MaxGroupCount>>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<
StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap,
GroupCount>>,
true>;
MaxGroupCount>>,
true,
MaxGroupCount>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
......@@ -370,6 +395,8 @@ struct DeviceGroupedGemmXdl
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
gemm_shapes,
arg.gemm_shapes_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -383,20 +410,22 @@ struct DeviceGroupedGemmXdl
CDataType,
remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1,
GroupCount>>,
MaxGroupCount>>,
remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1,
GroupCount>>,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
GroupCount>>,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<gemm_desc, MaxGroupCount>>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<
StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap,
GroupCount>>,
false>;
MaxGroupCount>>,
false,
MaxGroupCount>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
......@@ -409,6 +438,8 @@ struct DeviceGroupedGemmXdl
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
gemm_shapes,
arg.gemm_shapes_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -434,7 +465,6 @@ struct DeviceGroupedGemmXdl
static bool IsSupportedArgument(const Argument& arg)
{
const index_t i = 0;
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_[Number<0>{}],
arg.b_grid_desc_k0_n_k1_[Number<0>{}],
arg.c_grid_desc_m_n_[Number<0>{}],
......
......@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
exit(0);
}
int group_count = 1;
int group_count = 2;
// GEMM shape
std::vector<ck::gemm_desc> gemm_shapes;
......@@ -85,11 +85,11 @@ int main(int argc, char* argv[])
for(int i = 0; i < group_count; i++)
{
int M = 256;
int N = 512;
int K = 1024;
int M = 256 * (i + 1);
int N = 512 * (i + 1);
int K = 1024 * (i + 1);
gemm_shapes.push_back({M, N, K, K, K, N, A_size, B_size, C_size});
gemm_shapes.push_back({M, N, K, K, K, N, A_size, B_size, C_size, 0, 0});
A_size += M * K;
B_size += N * K;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment