Commit b1727e6b authored by rocking's avatar rocking
Browse files

Extract template parameter

parent 876acde3
...@@ -304,6 +304,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -304,6 +304,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
ReduceThreadTransferClusterLengths_MPerBlock_NPerBlock, ReduceThreadTransferClusterLengths_MPerBlock_NPerBlock,
ReduceThreadTransferScalarPerVector_NPerBlock, ReduceThreadTransferScalarPerVector_NPerBlock,
1,
LoopSched>; LoopSched>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
......
...@@ -80,6 +80,7 @@ template <typename ABDataType, ...@@ -80,6 +80,7 @@ template <typename ABDataType,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
index_t CDEReduceThreadTransferScalarPerVector_NPerBlock, index_t CDEReduceThreadTransferScalarPerVector_NPerBlock,
index_t FGTransferScalarPerVector,
LoopScheduler LoopSched> LoopScheduler LoopSched>
struct GridwiseGemmMultipleDWelford_xdl_cshuffle struct GridwiseGemmMultipleDWelford_xdl_cshuffle
{ {
...@@ -983,11 +984,10 @@ struct GridwiseGemmMultipleDWelford_xdl_cshuffle ...@@ -983,11 +984,10 @@ struct GridwiseGemmMultipleDWelford_xdl_cshuffle
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed( constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1)); make_tuple(I1, Number<mreduce_per_thread>{}, I1));
// TODO - extract template parameter
constexpr int scalarPerVector = 1;
constexpr int shuffleMPerBlock = constexpr int shuffleMPerBlock =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
static_assert(mreduce_per_thread % FGTransferScalarPerVector == 0);
auto f_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto f_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, AccDataType,
FDataType, FDataType,
...@@ -997,7 +997,7 @@ struct GridwiseGemmMultipleDWelford_xdl_cshuffle ...@@ -997,7 +997,7 @@ struct GridwiseGemmMultipleDWelford_xdl_cshuffle
Sequence<1, mreduce_per_thread, 1>, Sequence<1, mreduce_per_thread, 1>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
1, 1,
scalarPerVector, FGTransferScalarPerVector,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{f_grid_desc_mblock_mperblock_nblock, false>{f_grid_desc_mblock_mperblock_nblock,
...@@ -1016,7 +1016,7 @@ struct GridwiseGemmMultipleDWelford_xdl_cshuffle ...@@ -1016,7 +1016,7 @@ struct GridwiseGemmMultipleDWelford_xdl_cshuffle
Sequence<1, mreduce_per_thread, 1>, Sequence<1, mreduce_per_thread, 1>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
1, 1,
scalarPerVector, FGTransferScalarPerVector,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{g_grid_desc_mblock_mperblock_nblock, false>{g_grid_desc_mblock_mperblock_nblock,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment