Unverified Commit 14b3504d authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Update GetTypeString function to generate unique kernel IDs. (#638)

* make conv_fwd_bias_activation kernel id unique

* add more parameters to conv and gemm kernel names

* update GetTypeString for conv and gemm kernels

* fix two more kernel strings
parent ea028ac6
......@@ -640,7 +640,16 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">";
// clang-format on
......
......@@ -1003,7 +1003,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization)
<< getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
return str.str();
......
......@@ -1203,7 +1203,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization)
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1
<< ">";
// clang-format on
......
......@@ -1231,7 +1231,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization)
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1 << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">";
// clang-format on
......
......@@ -1092,7 +1092,15 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on
......
......@@ -838,7 +838,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">";
// clang-format on
......
......@@ -939,7 +939,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on
......
......@@ -688,6 +688,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">";
// clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment