Commit 40bcfcde authored by rocking

Evaluate perf of the kernel

parent 9402ee4b
@@ -84,23 +84,23 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
 using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
 // A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
-using DeviceNormalizeInstance = ck::tensor_operation::device::Device5AryElementwise<
-    CDataType,
-    DDataType,
-    DDataType,
-    GammaDataType,
-    BetaDataType,
-    LayerNormOutDataType,
-    NormalizeComputeDataType,
-    NormalizeFunctor,
-    2,
-    8,
-    8, // scalarPerVector: gemm_out
-    1, // scalarPerVector: reduce_mean
-    1, // scalarPerVector: reduce_mean_square
-    8, // scalarPerVector: Gamma
-    8, // scalarPerVector: Beta
-    8>; // scalarPerVector: LayerNorm_out
+using DeviceNormalizeInstance =
+    ck::tensor_operation::device::Device5AryElementwise<CDataType,
+                                                        DDataType,
+                                                        DDataType,
+                                                        GammaDataType,
+                                                        BetaDataType,
+                                                        LayerNormOutDataType,
+                                                        NormalizeComputeDataType,
+                                                        NormalizeFunctor,
+                                                        2,
+                                                        8,
+                                                        8, // scalarPerVector: gemm_out
+                                                        1, // scalarPerVector: reduce_mean
+                                                        1, // scalarPerVector: reduce_mean_square
+                                                        8, // scalarPerVector: Gamma
+                                                        8, // scalarPerVector: Beta
+                                                        8>; // scalarPerVector: LayerNorm_out
 
 auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
     return HostTensorDescriptor(std::vector<std::size_t>({len}),
@@ -189,10 +189,37 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
 }
 }
 
-int main()
-{
-    bool time_kernel = false;
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename DDataType,
+          typename GammaDataType,
+          typename BetaDataType,
+          typename NormalizeDataType>
+void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K)
+{
+    std::size_t gemm_flop = std::size_t(2) * M * N * K;
+    std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+                                sizeof(CDataType) * M * N + sizeof(DDataType) * M +
+                                sizeof(DDataType) * M;
+
+    std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(DDataType) * M +
+                                     sizeof(DDataType) * M + sizeof(GammaDataType) * N +
+                                     sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
+
+    float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
+    float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
+    float normalize_gb_per_sec = normalize_num_byte / 1.E6 / normalize_time;
+
+    std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, "
+              << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s" << std::endl;
+    std::cout << "normalize Perf: " << normalize_time << " ms, "
+              << normalize_gb_per_sec << " GB/s" << std::endl;
+}
+
+int main()
+{
 
     // GEMM shape
     ck::index_t M = 1024;
     ck::index_t N = 1024;
@@ -299,8 +326,8 @@ int main()
     }
 
     // run kernel
-    gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
-    normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});
+    gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false});
+    normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false});
 
     bool pass = true;
     {
@@ -327,5 +354,25 @@ int main()
                                               1e-3);
     }
 
+    {
+        // evaluate kernel perf
+        bool time_kernel = true;
+
+        float gemm_reduce_mean_reduce_square_mean_ave_time =
+            gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
+        float normalize_ave_time =
+            normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});
+
+        if(time_kernel)
+            DumpGemmLayerNormPerf<ADataType,
+                                  BDataType,
+                                  CDataType,
+                                  DDataType,
+                                  GammaDataType,
+                                  BetaDataType,
+                                  LayerNormOutDataType>(
+                gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K);
+    }
+
     return pass ? 0 : 1;
 }
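
Note on the arithmetic in `DumpGemmLayerNormPerf`: the times returned by `Invoker::Run` are in milliseconds, so dividing FLOPs by 1.E9 and by the time in ms gives TFLOPS, and dividing bytes by 1.E6 and by the time in ms gives GB/s. Below is a minimal standalone sketch of that bookkeeping for the GEMM-plus-reductions part. The 2-byte (fp16) and 4-byte (fp32) element sizes and the 0.1 ms timing are illustrative assumptions, not values taken from this diff; the real element sizes come from the `ADataType`/`BDataType`/`CDataType`/`DDataType` definitions earlier in the file.

```cpp
#include <cstddef>
#include <iostream>

int main()
{
    // Illustrative shape matching the example's defaults (M = N = K = 1024),
    // with assumed element sizes: 2-byte A/B/C, 4-byte per-row reduction outputs.
    const std::size_t M = 1024, N = 1024, K = 1024;
    const std::size_t elem_ab = 2, elem_c = 2, elem_d = 4;

    // Same formulas as DumpGemmLayerNormPerf: 2*M*N*K FLOPs for the GEMM, and
    // bytes moved by A (M*K), B (K*N), C (M*N), plus the two M-length reductions.
    const std::size_t gemm_flop = std::size_t(2) * M * N * K;
    const std::size_t gemm_num_byte =
        elem_ab * M * K + elem_ab * K * N + elem_c * M * N + elem_d * M + elem_d * M;

    // Hypothetical measured kernel time in milliseconds.
    const float gemm_reduce_time_ms = 0.1f;

    // flop / 1e9 / ms == flop / (1e12 * s) == TFLOPS; byte / 1e6 / ms == GB/s.
    const float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time_ms;
    const float gb_per_sec =
        static_cast<float>(gemm_num_byte) / 1.E6 / gemm_reduce_time_ms;

    // Prints roughly 21.47 TFlops and 63.0 GB/s for these assumed numbers.
    std::cout << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl;
    return 0;
}
```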