Commit 060d171f authored by ozturkosu
Browse files

Changed the hardcoded StreamKReductionStrategy from Atomic to Reduction

parent 3c7fef7f
...@@ -1010,10 +1010,16 @@ enum StreamKReductionStrategy ...@@ -1010,10 +1010,16 @@ enum StreamKReductionStrategy
Reduction, // let some workgroup responsible for doing the reduction operation Reduction, // let some workgroup responsible for doing the reduction operation
}; };
// template <uint32_t MPerBlock_,
// uint32_t NPerBlock_,
// uint32_t KPerBlock_,
// StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
// uint32_t TileSwizzleSubM_ = 8>
template <uint32_t MPerBlock_, template <uint32_t MPerBlock_,
uint32_t NPerBlock_, uint32_t NPerBlock_,
uint32_t KPerBlock_, uint32_t KPerBlock_,
StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic, StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Reduction,
uint32_t TileSwizzleSubM_ = 8> uint32_t TileSwizzleSubM_ = 8>
struct BlockToCTileMap_GemmStreamK struct BlockToCTileMap_GemmStreamK
{ {
......
...@@ -539,10 +539,18 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -539,10 +539,18 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const ADataType* p_a_grid; const ADataType* p_a_grid;
const BDataType* p_b_grid; const BDataType* p_b_grid;
CDataType* p_c_grid; CDataType* p_c_grid;
// BlockToCTileMap_GemmStreamK_v2<MPerBlock,
// NPerBlock,
// KPerBlock,
// StreamKReductionStrategy::Atomic,
// 8,
// 4>
// block_2_ctile_map_streamk;
BlockToCTileMap_GemmStreamK_v2<MPerBlock, BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
StreamKReductionStrategy::Atomic, StreamKReductionStrategy::Reduction,
8, 8,
4> 4>
block_2_ctile_map_streamk; block_2_ctile_map_streamk;
...@@ -1176,10 +1184,18 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1176,10 +1184,18 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}(); }();
return c_partial_acc_block_m_n; return c_partial_acc_block_m_n;
} }
// using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
// NPerBlock,
// KPerBlock,
// StreamKReductionStrategy::Atomic,
// 8,
// 4>;
using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock, using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
StreamKReductionStrategy::Atomic, StreamKReductionStrategy::Reduction,
8, 8,
4>; 4>;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment