Commit 55ab4687 authored by Jing Zhang's avatar Jing Zhang
Browse files

perf test

parent bbe5c0c7
...@@ -29,19 +29,20 @@ __global__ void ...@@ -29,19 +29,20 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdlops_v2r3( kernel_grouped_gemm_xdlops_v2r3(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const StaticallyIndexedArray<AGridDesc_K0_M_K1, MaxGroupCount> a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const StaticallyIndexedArray<BGridDesc_K0_N_K1, MaxGroupCount> b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const StaticallyIndexedArray<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, MaxGroupCount>
const GemmDesc gemm_shapes, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_shapes,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map) const StaticallyIndexedArray<Block2CTileMap, MaxGroupCount> block_2_ctile_map)
{ {
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -54,8 +55,6 @@ __global__ void ...@@ -54,8 +55,6 @@ __global__ void
index_t c_offset_grp = 0; index_t c_offset_grp = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) { static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < group_count)
{
if(block_id >= gemm_shapes[i].BlockStart && if(block_id >= gemm_shapes[i].BlockStart &&
block_id < (gemm_shapes[i].BlockStart + gemm_shapes[i].BlockSize)) block_id < (gemm_shapes[i].BlockStart + gemm_shapes[i].BlockSize))
{ {
...@@ -65,6 +64,19 @@ __global__ void ...@@ -65,6 +64,19 @@ __global__ void
b_offset_grp = gemm_shapes[i].OffsetB; b_offset_grp = gemm_shapes[i].OffsetB;
c_offset_grp = gemm_shapes[i].OffsetC; c_offset_grp = gemm_shapes[i].OffsetC;
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1[i],
b_grid_desc_k0_n_k1[i],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[i],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map[i],
block_id_grp);
// if(get_thread_local_1d_id() == 0) // if(get_thread_local_1d_id() == 0)
// printf("%d %d %d %d %d %d\n", // printf("%d %d %d %d %d %d\n",
// block_id, // block_id,
...@@ -74,40 +86,91 @@ __global__ void ...@@ -74,40 +86,91 @@ __global__ void
// b_offset_grp, // b_offset_grp,
// c_offset_grp); // c_offset_grp);
} }
});
}
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r4(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const StaticallyIndexedArray<AGridDesc_K0_M_K1, MaxGroupCount> a_grid_desc_k0_m_k1,
const StaticallyIndexedArray<BGridDesc_K0_N_K1, MaxGroupCount> b_grid_desc_k0_n_k1,
const StaticallyIndexedArray<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, MaxGroupCount>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_shapes,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const StaticallyIndexedArray<Block2CTileMap, MaxGroupCount> block_2_ctile_map)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
__shared__ AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_[MaxGroupCount];
__shared__ BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_[MaxGroupCount];
__shared__ CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[MaxGroupCount];
__shared__ Block2CTileMap block_2_ctile_map_[MaxGroupCount];
__shared__ GemmDesc gemm_shapes_[MaxGroupCount];
if(get_thread_local_1d_id())
{
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
a_grid_desc_k0_m_k1_[i] = a_grid_desc_k0_m_k1[i];
b_grid_desc_k0_n_k1_[i] = b_grid_desc_k0_n_k1[i];
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[i] = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[i];
block_2_ctile_map_[i] = block_2_ctile_map[i];
gemm_shapes_[i] = gemm_shapes[i];
});
} }
block_sync_lds();
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_shapes[i].BlockStart &&
block_id < (gemm_shapes[i].BlockStart + gemm_shapes[i].BlockSize))
? i
: group_id;
}); });
constexpr auto I0 = Number<0>{}; const index_t block_id_grp = block_id - gemm_shapes_[group_id].BlockStart;
constexpr auto I1 = Number<1>{}; const index_t a_offset_grp = gemm_shapes_[group_id].OffsetA;
const index_t b_offset_grp = gemm_shapes_[group_id].OffsetB;
const index_t c_offset_grp = gemm_shapes_[group_id].OffsetC;
if(group_id == 0)
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1[I0],
b_grid_desc_k0_n_k1[I0],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[I0],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map[I0],
block_id_grp,
group_id);
else if(group_id == 1)
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp, GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp, p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp, p_c_grid + c_offset_grp,
p_shared, p_shared,
a_grid_desc_k0_m_k1[I1], a_grid_desc_k0_m_k1_[group_id],
b_grid_desc_k0_n_k1[I1], b_grid_desc_k0_n_k1_[group_id],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[I1], c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[group_id],
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
block_2_ctile_map[I1], block_2_ctile_map_[group_id],
block_id_grp, block_id_grp);
group_id);
} }
template <index_t BlockSize, template <index_t BlockSize,
...@@ -407,8 +470,7 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -407,8 +470,7 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const index_t block_id, const index_t block_id)
const index_t group_id)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
......
...@@ -350,6 +350,11 @@ struct DeviceGroupedGemmXdl ...@@ -350,6 +350,11 @@ struct DeviceGroupedGemmXdl
std::cout << "group_id " << i << " BlockStart " << gemm_shapes(i).BlockStart std::cout << "group_id " << i << " BlockStart " << gemm_shapes(i).BlockStart
<< " BlockSize " << gemm_shapes(i).BlockSize << std::endl; << " BlockSize " << gemm_shapes(i).BlockSize << std::endl;
} }
else
{
gemm_shapes(i).BlockStart = -1;
gemm_shapes(i).BlockSize = -1;
}
}); });
const auto K0 = arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I0); const auto K0 = arg.a_grid_desc_k0_m_k1_[Number<0>{}].GetLength(I0);
...@@ -361,26 +366,18 @@ struct DeviceGroupedGemmXdl ...@@ -361,26 +366,18 @@ struct DeviceGroupedGemmXdl
#if 1 #if 1
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel = kernel_grouped_gemm_xdlops_v2r3<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t< remove_reference_t<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1>,
StaticallyIndexedArray<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1, remove_reference_t<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1>,
MaxGroupCount>>, remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t< remove_reference_t<gemm_desc>,
StaticallyIndexedArray<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<gemm_desc, MaxGroupCount>>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t< remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap,
MaxGroupCount>>,
true, true,
MaxGroupCount>; MaxGroupCount>;
...@@ -404,26 +401,18 @@ struct DeviceGroupedGemmXdl ...@@ -404,26 +401,18 @@ struct DeviceGroupedGemmXdl
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel = kernel_grouped_gemm_xdlops_v2r3<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t< remove_reference_t<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1>,
StaticallyIndexedArray<DeviceGroupedGemmXdl::AGridDesc_K0_M_K1, remove_reference_t<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1>,
MaxGroupCount>>, remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t< remove_reference_t<gemm_desc>,
StaticallyIndexedArray<DeviceGroupedGemmXdl::BGridDesc_K0_N_K1,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
MaxGroupCount>>,
remove_reference_t<StaticallyIndexedArray<gemm_desc, MaxGroupCount>>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t< remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
StaticallyIndexedArray<typename GridwiseGemm::DefaultBlock2CTileMap,
MaxGroupCount>>,
false, false,
MaxGroupCount>; MaxGroupCount>;
......
...@@ -76,7 +76,7 @@ int main(int argc, char* argv[]) ...@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
int group_count = 2; int group_count = 3;
// GEMM shape // GEMM shape
std::vector<ck::gemm_desc> gemm_shapes; std::vector<ck::gemm_desc> gemm_shapes;
...@@ -85,11 +85,11 @@ int main(int argc, char* argv[]) ...@@ -85,11 +85,11 @@ int main(int argc, char* argv[])
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
int M = 256 * (i + 1); int M = 2048 + 256 * i;
int N = 512 * (i + 1); int N = 2048 + 128 * i;
int K = 1024 * (i + 1); int K = 256 + 128 * i;
gemm_shapes.push_back({M, N, K, K, K, N, A_size, B_size, C_size, 0, 0}); gemm_shapes.push_back({M, N, K, K, K, N, A_size, B_size, C_size, -1, -1});
A_size += M * K; A_size += M * K;
B_size += N * K; B_size += N * K;
...@@ -115,6 +115,8 @@ int main(int argc, char* argv[]) ...@@ -115,6 +115,8 @@ int main(int argc, char* argv[])
std::vector<Tensor<CDataType>> c_host_tensors; std::vector<Tensor<CDataType>> c_host_tensors;
std::vector<Tensor<CDataType>> c_device_tensors; std::vector<Tensor<CDataType>> c_device_tensors;
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < gemm_shapes.size(); i++) for(int i = 0; i < gemm_shapes.size(); i++)
{ {
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor( a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
...@@ -129,6 +131,11 @@ int main(int argc, char* argv[]) ...@@ -129,6 +131,11 @@ int main(int argc, char* argv[])
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl; << std::endl;
flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N;
num_btype += sizeof(ADataType) * gemm_shapes[i].M * gemm_shapes[i].K +
sizeof(BDataType) * gemm_shapes[i].K * gemm_shapes[i].N +
sizeof(CDataType) * gemm_shapes[i].M * gemm_shapes[i].N;
} }
for(int i = 0; i < gemm_shapes.size(); i++) for(int i = 0; i < gemm_shapes.size(); i++)
...@@ -192,6 +199,13 @@ int main(int argc, char* argv[]) ...@@ -192,6 +199,13 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, nrepeat);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_tensors_data.resize(C_size); c_tensors_data.resize(C_size);
c_tensors_device_buf.FromDevice(c_tensors_data.data()); c_tensors_device_buf.FromDevice(c_tensors_data.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment