Commit 552053e2 authored by Paul's avatar Paul
Browse files

Format

parent 9e3825a2
...@@ -31,7 +31,7 @@ static std::string GetGemmSpec(const std::size_t m, ...@@ -31,7 +31,7 @@ static std::string GetGemmSpec(const std::size_t m,
} }
template <class F> template <class F>
std::vector<Operation_Xdl_CShuffle> CreateOperationsImpl(F f) std::vector<Operation_Xdl_CShuffle> CreateOperationsImpl(F f, Layout ALayout, Layout BLayout)
{ {
std::vector<Operation_Xdl_CShuffle> result; std::vector<Operation_Xdl_CShuffle> result;
// Tile Desc: (block_size, m_per_block, n_per_block, k_per_block, ak1, bk1, // Tile Desc: (block_size, m_per_block, n_per_block, k_per_block, ak1, bk1,
...@@ -54,36 +54,38 @@ std::vector<Operation_Xdl_CShuffle> CreateOperationsImpl(F f) ...@@ -54,36 +54,38 @@ std::vector<Operation_Xdl_CShuffle> CreateOperationsImpl(F f)
// BlockTransferDesc: (thread_cluster_length, thread_cluster_arrange_order, src_access_order, // BlockTransferDesc: (thread_cluster_length, thread_cluster_arrange_order, src_access_order,
// src_vec_dim, src_scalar_per_vector, dst_scalar_per_vector_k1, lds_add_extra_dim ) // src_vec_dim, src_scalar_per_vector, dst_scalar_per_vector_k1, lds_add_extra_dim )
auto ABlockTransferSrcVectorDim = ALayout == Layout::Column ? 1 : 2;
std::vector<operation::BlockTransferDesc> a_block_descriptions = { std::vector<operation::BlockTransferDesc> a_block_descriptions = {
{S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 8, 8, 1},
}; };
auto BBlockTransferSrcVectorDim = BLayout == Layout::Row ? 1 : 2;
std::vector<operation::BlockTransferDesc> b_block_descriptions = { std::vector<operation::BlockTransferDesc> b_block_descriptions = {
{S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
{S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, {S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 8, 8, 1},
}; };
// cshuffle_descriptions: (m_Xdl_per_wave_per_shuffle, n_Xdl_per_wave_per_shuffle) // cshuffle_descriptions: (m_Xdl_per_wave_per_shuffle, n_Xdl_per_wave_per_shuffle)
...@@ -144,7 +146,9 @@ static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row ...@@ -144,7 +146,9 @@ static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations() std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations()
{ {
return CreateOperationsImpl([](auto x) -> std::vector<Operation_Xdl_CShuffle> { return {x}; }); return CreateOperationsImpl([](auto x) -> std::vector<Operation_Xdl_CShuffle> { return {x}; },
Layout::Column,
Layout::Row);
} }
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(const Problem& prob) std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(const Problem& prob)
{ {
...@@ -166,7 +170,9 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(con ...@@ -166,7 +170,9 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(con
x.tile_desc.n_per_block, x.tile_desc.n_per_block,
x.tile_desc.k_per_block); x.tile_desc.k_per_block);
return {x}; return {x};
}); },
ToLayout(prob.TransA),
ToLayout(prob.TransB));
} }
static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment