Commit 55ab4687 authored by Jing Zhang

perf test

parent bbe5c0c7
@@ -29,19 +29,20 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r3(
kernel_grouped_gemm_xdlops_v2r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const GemmDesc gemm_shapes,
const StaticallyIndexedArray<AGridDesc_K0_M_K1, MaxGroupCount> a_grid_desc_k0_m_k1,
const StaticallyIndexedArray<BGridDesc_K0_N_K1, MaxGroupCount> b_grid_desc_k0_n_k1,
const StaticallyIndexedArray<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, MaxGroupCount>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_shapes,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
const StaticallyIndexedArray<Block2CTileMap, MaxGroupCount> block_2_ctile_map)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -54,60 +55,122 @@ __global__ void
index_t c_offset_grp = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < group_count)
if(block_id >= gemm_shapes[i].BlockStart &&
block_id < (gemm_shapes[i].BlockStart + gemm_shapes[i].BlockSize))
{
if(block_id >= gemm_shapes[i].BlockStart &&
block_id < (gemm_shapes[i].BlockStart + gemm_shapes[i].BlockSize))
{
group_id = i;
block_id_grp = block_id - gemm_shapes[i].BlockStart;
a_offset_grp = gemm_shapes[i].OffsetA;
b_offset_grp = gemm_shapes[i].OffsetB;
c_offset_grp = gemm_shapes[i].OffsetC;
// if(get_thread_local_1d_id() == 0)
// printf("%d %d %d %d %d %d\n",
// block_id,
// group_id,
// block_id_grp,
// a_offset_grp,
// b_offset_grp,
// c_offset_grp);
}
group_id = i;
block_id_grp = block_id - gemm_shapes[i].BlockStart;
a_offset_grp = gemm_shapes[i].OffsetA;
b_offset_grp = gemm_shapes[i].OffsetB;
c_offset_grp = gemm_shapes[i].OffsetC;
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1[i],
b_grid_desc_k0_n_k1[i],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[i],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map[i],
block_id_grp);
// if(get_thread_local_1d_id() == 0)
// printf("%d %d %d %d %d %d\n",
// block_id,
// group_id,
// block_id_grp,
// a_offset_grp,
// b_offset_grp,
// c_offset_grp);
}
});
}
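For reference, the group lookup in kernel_grouped_gemm_xdlops_v2r3 above is a linear scan over the half-open intervals [BlockStart, BlockStart + BlockSize): the first interval that contains the flat block id selects the group, and the block id is re-based to a group-local index. Below is a minimal host-side sketch of that lookup, assuming packed back-to-back block ranges; GemmDescLite, find_group and the sample layout are illustrative names only, not part of the library.

#include <cstdio>
#include <vector>

// Hypothetical, trimmed-down per-group descriptor holding only the block range.
struct GemmDescLite
{
    int BlockStart; // first flat block id owned by this group
    int BlockSize;  // number of blocks owned by this group
};

// Returns the group index owning block_id, or -1 if no interval matches.
// block_id_grp receives the group-local block id.
int find_group(const std::vector<GemmDescLite>& groups, int block_id, int& block_id_grp)
{
    for(int i = 0; i < static_cast<int>(groups.size()); ++i)
    {
        if(block_id >= groups[i].BlockStart &&
           block_id < groups[i].BlockStart + groups[i].BlockSize)
        {
            block_id_grp = block_id - groups[i].BlockStart;
            return i;
        }
    }
    return -1;
}

int main()
{
    // Three groups of 8, 4 and 16 blocks laid out back-to-back in one flat grid.
    std::vector<GemmDescLite> groups = {{0, 8}, {8, 4}, {12, 16}};
    for(int block_id : {0, 7, 8, 11, 12, 27})
    {
        int block_id_grp = -1;
        int group_id     = find_group(groups, block_id, block_id_grp);
        std::printf("block %2d -> group %d, local block %d\n", block_id, group_id, block_id_grp);
    }
    return 0;
}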
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r4(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const StaticallyIndexedArray<AGridDesc_K0_M_K1, MaxGroupCount> a_grid_desc_k0_m_k1,
const StaticallyIndexedArray<BGridDesc_K0_N_K1, MaxGroupCount> b_grid_desc_k0_n_k1,
const StaticallyIndexedArray<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, MaxGroupCount>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_shapes,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const StaticallyIndexedArray<Block2CTileMap, MaxGroupCount> block_2_ctile_map)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
__shared__ AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_[MaxGroupCount];
__shared__ BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_[MaxGroupCount];
__shared__ CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[MaxGroupCount];
__shared__ Block2CTileMap block_2_ctile_map_[MaxGroupCount];
__shared__ GemmDesc gemm_shapes_[MaxGroupCount];
if(get_thread_local_1d_id())
{
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
a_grid_desc_k0_m_k1_[i] = a_grid_desc_k0_m_k1[i];
b_grid_desc_k0_n_k1_[i] = b_grid_desc_k0_n_k1[i];
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[i] = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[i];
block_2_ctile_map_[i] = block_2_ctile_map[i];
gemm_shapes_[i] = gemm_shapes[i];
});
}
block_sync_lds();
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_shapes[i].BlockStart &&
block_id < (gemm_shapes[i].BlockStart + gemm_shapes[i].BlockSize))
? i
: group_id;
});
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
if(group_id == 0)
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1[I0],
b_grid_desc_k0_n_k1[I0],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[I0],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map[I0],
block_id_grp,
group_id);
else if(group_id == 1)
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1[I1],
b_grid_desc_k0_n_k1[I1],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[I1],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map[I1],
block_id_grp,
group_id);
const index_t block_id_grp = block_id - gemm_shapes_[group_id].BlockStart;
const index_t a_offset_grp = gemm_shapes_[group_id].OffsetA;
const index_t b_offset_grp = gemm_shapes_[group_id].OffsetB;
const index_t c_offset_grp = gemm_shapes_[group_id].OffsetC;
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1_[group_id],
b_grid_desc_k0_n_k1_[group_id],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[group_id],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map_[group_id],
block_id_grp);
}
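kernel_grouped_gemm_xdlops_v2r4 above picks the group differently: every interval is tested and group_id is overwritten through a ternary, which trades a few redundant comparisons for uniform control flow inside the static_for. A host-side sketch of that predicated selection follows (find_group_predicated is an illustrative name; blocks outside every interval fall back to group 0, mirroring the kernel's default).

#include <cstdio>
#include <vector>

struct BlockRange
{
    int BlockStart;
    int BlockSize;
};

// Predicated scan: evaluate every interval unconditionally and keep the last match.
int find_group_predicated(const std::vector<BlockRange>& groups, int block_id)
{
    int group_id = 0; // default when nothing matches, as in the kernel
    for(int i = 0; i < static_cast<int>(groups.size()); ++i)
    {
        group_id = (block_id >= groups[i].BlockStart &&
                    block_id < groups[i].BlockStart + groups[i].BlockSize)
                       ? i
                       : group_id;
    }
    return group_id;
}

int main()
{
    std::vector<BlockRange> groups = {{0, 8}, {8, 4}, {12, 16}};
    std::printf("block 13 -> group %d\n", find_group_predicated(groups, 13));
    return 0;
}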
template <index_t BlockSize,
@@ -407,8 +470,7 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map,
const index_t block_id,
const index_t group_id)
const index_t block_id)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
......
@@ -350,6 +350,11 @@ struct DeviceGroupedGemmXdl
std::cout << "group_id " << i << " BlockStart " << gemm_shapes(i).BlockStart
<< " BlockSize " << gemm_shapes(i).BlockSize << std::endl;
}
else
{
gemm_shapes(i).BlockStart = -1;
gemm_shapes(i).BlockSize = -1;
}
});
const auto K0 = arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I0);
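The DeviceGroupedGemmXdl hunk above prints BlockStart and BlockSize for live groups and writes -1 into the unused descriptor slots, but the diff does not show how those two fields are derived. A plausible host-side sketch, assuming one workgroup per MPerBlock x NPerBlock tile of C and groups laid out back-to-back in a single flat grid; the tile sizes and compute_block_ranges are assumptions for illustration, not the library's actual code.

#include <array>
#include <cstdio>
#include <vector>

constexpr int MPerBlock = 256; // assumed tile sizes; the real values come from GridwiseGemm
constexpr int NPerBlock = 128;

struct BlockRange
{
    int BlockStart; // first flat block id owned by this group
    int BlockSize;  // number of C tiles (workgroups) owned by this group
};

std::vector<BlockRange> compute_block_ranges(const std::vector<std::array<int, 2>>& mn)
{
    std::vector<BlockRange> ranges;
    int block_start = 0;
    for(auto [M, N] : mn)
    {
        const int grid_m = (M + MPerBlock - 1) / MPerBlock; // tiles along M
        const int grid_n = (N + NPerBlock - 1) / NPerBlock; // tiles along N
        ranges.push_back({block_start, grid_m * grid_n});
        block_start += grid_m * grid_n; // next group starts right after this one
    }
    return ranges;
}

int main()
{
    // Shapes from the modified example: M = 2048 + 256*i, N = 2048 + 128*i for i = 0..2.
    const auto ranges = compute_block_ranges({{2048, 2048}, {2304, 2176}, {2560, 2304}});
    for(std::size_t i = 0; i < ranges.size(); ++i)
        std::printf("group %zu: BlockStart %d BlockSize %d\n",
                    i,
                    ranges[i].BlockStart,
                    ranges[i].BlockSize);
    return 0;
}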
@@ -361,26 +366,18 @@ struct DeviceGroupedGemmXdl
#if 1
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdlops_v2r3<
const auto kernel = kernel_grouped_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1,
MaxGroupCount>>,
remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<gemm_desc, MaxGroupCount>>,
remove_reference_t<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<gemm_desc>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<
StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap,
MaxGroupCount>>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true,
MaxGroupCount>;
@@ -404,26 +401,18 @@ struct DeviceGroupedGemmXdl
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r3<
const auto kernel = kernel_grouped_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1,
MaxGroupCount>>,
remove_reference_t<
StaticallyIndexedArray<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<gemm_desc, MaxGroupCount>>,
remove_reference_t<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<gemm_desc>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<
StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap,
MaxGroupCount>>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false,
MaxGroupCount>;
......
@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
exit(0);
}
int group_count = 2;
int group_count = 3;
// GEMM shape
std::vector<ck::gemm_desc> gemm_shapes;
@@ -85,11 +85,11 @@ int main(int argc, char* argv[])
for(int i = 0; i < group_count; i++)
{
int M = 256 * (i + 1);
int N = 512 * (i + 1);
int K = 1024 * (i + 1);
int M = 2048 + 256 * i;
int N = 2048 + 128 * i;
int K = 256 + 128 * i;
gemm_shapes.push_back({M, N, K, K, K, N, A_size, B_size, C_size, 0, 0});
gemm_shapes.push_back({M, N, K, K, K, N, A_size, B_size, C_size, -1, -1});
A_size += M * K;
B_size += N * K;
@@ -115,6 +115,8 @@ int main(int argc, char* argv[])
std::vector<Tensor<CDataType>> c_host_tensors;
std::vector<Tensor<CDataType>> c_device_tensors;
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
@@ -129,6 +131,11 @@ int main(int argc, char* argv[])
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N;
num_btype += sizeof(ADataType) * gemm_shapes[i].M * gemm_shapes[i].K +
sizeof(BDataType) * gemm_shapes[i].K * gemm_shapes[i].N +
sizeof(CDataType) * gemm_shapes[i].M * gemm_shapes[i].N;
}
for(int i = 0; i < gemm_shapes.size(); i++)
@@ -192,6 +199,13 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, nrepeat);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_tensors_data.resize(C_size);
c_tensors_device_buf.FromDevice(c_tensors_data.data());
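The perf report added above divides by 1.E9 and 1.E6 even though ave_time is measured in milliseconds; the constants are consistent because flop / 1e9 / t_ms equals flop / (1e12 * t_s), i.e. TFLOP/s, and bytes / 1e6 / t_ms equals GB/s. A small self-contained check of that conversion follows; the fp16 element size and the 0.05 ms timing are illustrative assumptions, not measured values.

#include <cstdio>

int main()
{
    // One group of the modified test: M = 2048, N = 2048, K = 256.
    const double flop      = 2.0 * 2048 * 2048 * 256;                             // 2*M*N*K
    const double num_btype = 2.0 * (2048.0 * 256 + 256.0 * 2048 + 2048.0 * 2048); // A + B + C, 2 bytes/element (assumed fp16)
    const double ave_time  = 0.05;                                                // milliseconds (assumed)

    const double tflops     = flop / 1.E9 / ave_time;      // flop / 1e9 / ms == TFLOP/s
    const double gb_per_sec = num_btype / 1.E6 / ave_time; // bytes / 1e6 / ms == GB/s

    std::printf("Perf: %.3f ms, %.3f TFlops, %.3f GB/s\n", ave_time, tflops, gb_per_sec);
    return 0;
}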
......