Commit 250a89f3 authored by Mirza Halilcevic's avatar Mirza Halilcevic
Browse files

Replace gemm_gemm with gemm_multiple_d_gemm_multiple_d.

parent d1e9682a
...@@ -8,13 +8,13 @@ ...@@ -8,13 +8,13 @@
#include <string> #include <string>
#include "ck/host/types.hpp" #include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp" #include "ck/host/operation/gemm.hpp"
#include "ck/host/device_gemm_elementwise_gemm/problem.hpp" #include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp"
namespace ck { namespace ck {
namespace host { namespace host {
namespace device_gemm_elementwise_gemm { namespace device_batched_gemm_multiple_d_gemm_multiple_d {
// defines all values need for an instance of fwd conv // defines all values needed for an instance
struct Operation_Xdl_CShuffle struct Operation_Xdl_CShuffle
{ {
// returns a vector of instances, only given fusion operators: will use default problem spec // returns a vector of instances, only given fusion operators: will use default problem spec
...@@ -23,36 +23,40 @@ struct Operation_Xdl_CShuffle ...@@ -23,36 +23,40 @@ struct Operation_Xdl_CShuffle
// returns a vector of instances, given a problem spec and fusion operators // returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle> static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue); CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{}; TensorDesc A0{};
TensorDesc B0{}; TensorDesc B0{};
std::vector<TensorDesc> D0s = {};
TensorDesc B1{}; TensorDesc B1{};
TensorDesc C{}; std::vector<TensorDesc> D1s = {};
DataType acc = DataType::Float; TensorDesc E1{};
DataType cs_type = DataType::Half; DataType acc_type = DataType::Float;
std::string a_elem_op = PassThrough; DataType cshuffle_type = DataType::Float;
std::string b0_elem_op = PassThrough; std::string a0_elem_op = PassThrough;
std::string acc0_elem_op = PassThrough; std::string b0_elem_op = PassThrough;
std::string b1_elem_op = PassThrough; std::string cde0_elem_op = PassThrough;
std::string c_elem_op = PassThrough; std::string b1_elem_op = PassThrough;
std::string prologue = ""; std::string cde1_elem_op = PassThrough;
std::string epilogue = ""; std::string prologue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; std::string epilogue = "";
// tuning parameters // tuning parameters
operation::TileDescGemmElementwiseGemm tile_desc{}; operation::PaddingDesc padding_desc{};
operation::BlockTransferDesc a_block_transfer{}; operation::TileDescGemmGemm tile_desc{};
operation::BlockTransferDesc a0_block_transfer{};
operation::BlockTransferDesc b0_block_transfer{}; operation::BlockTransferDesc b0_block_transfer{};
operation::BlockTransferDesc cde0_block_transfer{};
operation::BlockTransferDesc b1_block_transfer{}; operation::BlockTransferDesc b1_block_transfer{};
operation::CShuffleDesc cshuffle{}; operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{}; operation::CBlockTransferDesc cde1_block_transfer{};
// functions to update fusion operators if provided // functions to update fusion operators if provided
void update_prologue(const std::string& prologue); void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue); void update_epilogue(const std::string& epilogue);
/**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_); /**constexpr**/ bool
IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_);
// returns a templated instance // returns a templated instance
Solution ToSolution() const; Solution ToSolution() const;
}; };
} // namespace device_gemm_elementwise_gemm } // namespace device_batched_gemm_multiple_d_gemm_multiple_d
} // namespace host } // namespace host
} // namespace ck } // namespace ck
...@@ -10,28 +10,32 @@ ...@@ -10,28 +10,32 @@
namespace ck { namespace ck {
namespace host { namespace host {
namespace device_gemm_elementwise_gemm { namespace device_batched_gemm_multiple_d_gemm_multiple_d {
// defines the problem specification for a GEMM operation // defines the problem specification for a GEMM_ELEMENTWISE_GEMM operation
struct Problem struct Problem
{ {
std::size_t M = 0; std::size_t M = 0;
std::size_t N = 0; std::size_t N = 0;
std::size_t K = 0; std::size_t K = 0;
std::size_t O = 0; std::size_t O = 0;
bool TransA = false; bool TransA0 = false;
bool TransB0 = false; bool TransB0 = false;
bool TransB1 = false; std::vector<bool> D0sTrans = {};
bool TransC = false; bool TransB1 = false;
DataType ADataType = DataType::Half; std::vector<bool> D1sTrans = {};
DataType B0DataType = DataType::Half; bool TransE1 = false;
DataType B1DataType = DataType::Half; DataType A0DataType = DataType::Half;
DataType CDataType = DataType::Half; DataType B0DataType = DataType::Half;
std::string AElementOp = PassThrough; std::vector<DataType> D0sDataType = {};
std::string B0ElementOp = PassThrough; DataType B1DataType = DataType::Half;
std::string Acc0ElementOp = PassThrough; std::vector<DataType> D1sDataType = {};
std::string B1ElementOp = PassThrough; DataType E1DataType = DataType::Half;
std::string CElementOp = PassThrough; std::string A0ElementOp = PassThrough;
std::string B0ElementOp = PassThrough;
std::string CDE0ElementOp = PassThrough;
std::string B1ElementOp = PassThrough;
std::string CDE1ElementOp = PassThrough;
// returns the correct device op file for the operation // returns the correct device op file for the operation
std::string GetIncludeHeader() const; std::string GetIncludeHeader() const;
...@@ -42,6 +46,6 @@ struct Problem ...@@ -42,6 +46,6 @@ struct Problem
const std::string& epilogue) const; const std::string& epilogue) const;
}; };
} // namespace device_gemm_elementwise_gemm } // namespace device_batched_gemm_multiple_d_gemm_multiple_d
} // namespace host } // namespace host
} // namespace ck } // namespace ck
...@@ -9,6 +9,15 @@ namespace ck { ...@@ -9,6 +9,15 @@ namespace ck {
namespace host { namespace host {
namespace operation { namespace operation {
struct PaddingDesc
{
bool pad_gemm0_m = 0;
bool pad_gemm0_n = 0;
bool pad_gemm0_k = 0;
bool pad_gemm1_n = 0;
bool pad_gemm1_k = 0;
};
struct TileDesc struct TileDesc
{ {
int block_size = 0; int block_size = 0;
...@@ -24,23 +33,23 @@ struct TileDesc ...@@ -24,23 +33,23 @@ struct TileDesc
int num_gemmk_prefetch_stage = 0; int num_gemmk_prefetch_stage = 0;
}; };
struct TileDescGemmElementwiseGemm struct TileDescGemmGemm
{ {
int block_size = 0; int block_size = 0;
int gemm01_m_per_block = 0; int gemm0_m_per_block = 0;
int gemm0_n_per_block = 0; int gemm0_n_per_block = 0;
int gemm0_k_per_block = 0; int gemm0_k_per_block = 0;
int gemm1_n_per_block = 0; int gemm1_n_per_block = 0;
int gemm1_k_per_block = 0; int gemm1_k_per_block = 0;
int ak1 = 0; int a0k1 = 0;
int bk1 = 0; int b0k1 = 0;
int b1k1 = 0; int b1k1 = 0;
int m_per_XDL = 0; int m_per_XDL = 0;
int n_per_XDL = 0; int n_per_XDL = 0;
int gemm0_m_Xdl_per_wave = 0; int gemm0_m_Xdl_per_wave = 0;
int gemm0_n_Xdl_per_wave = 0; int gemm0_n_Xdl_per_wave = 0;
int gemm1_n_Xdl_per_wave = 0; int gemm1_n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0; int num_gemm0k_prefetch_stage = 0;
}; };
struct BlockTransferDesc struct BlockTransferDesc
...@@ -53,11 +62,13 @@ struct BlockTransferDesc ...@@ -53,11 +62,13 @@ struct BlockTransferDesc
int dst_scalar_per_vector_k1 = 0; int dst_scalar_per_vector_k1 = 0;
int lds_add_extra_dim = 0; int lds_add_extra_dim = 0;
}; };
struct CShuffleDesc struct CShuffleDesc
{ {
int m_Xdl_per_wave_per_shuffle = 0; int m_Xdl_per_wave_per_shuffle = 0;
int n_Xdl_per_wave_per_shuffle = 0; int n_Xdl_per_wave_per_shuffle = 0;
}; };
struct CBlockTransferDesc struct CBlockTransferDesc
{ {
std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = ""; std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = "";
......
...@@ -2,19 +2,20 @@ ...@@ -2,19 +2,20 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_elementwise_gemm/problem.hpp" #include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_elementwise_gemm/operation.hpp" #include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/operation.hpp"
#include "ck/host/utils.hpp" #include "ck/host/utils.hpp"
#include <algorithm> #include <algorithm>
namespace ck { namespace ck {
namespace host { namespace host {
namespace device_gemm_elementwise_gemm { namespace device_batched_gemm_multiple_d_gemm_multiple_d {
// return the relevant device op file based on the operation // return the relevant device op file based on the operation
std::string Problem::GetIncludeHeader() const std::string Problem::GetIncludeHeader() const
{ {
return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"; return "ck/tensor_operation/gpu/device/impl/"
"device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp";
} }
// returns templated instances when provided with a problem specification // returns templated instances when provided with a problem specification
...@@ -24,8 +25,8 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch, ...@@ -24,8 +25,8 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
{ {
if(get_xdlop_archs().count(arch) == 0) if(get_xdlop_archs().count(arch) == 0)
return {}; return {};
auto ops = ck::host::device_gemm_elementwise_gemm::Operation_Xdl_CShuffle::CreateOperations( auto ops = ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Operation_Xdl_CShuffle::
*this, prologue, epilogue); // obtains vector of instances CreateOperations(*this, prologue, epilogue); // obtains vector of instances
std::vector<Solution> result; std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution(); // template instance with correct values return op.ToSolution(); // template instance with correct values
...@@ -33,6 +34,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch, ...@@ -33,6 +34,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch,
return result; return result;
} }
} // namespace device_gemm_elementwise_gemm } // namespace device_batched_gemm_multiple_d_gemm_multiple_d
} // namespace host } // namespace host
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_elementwise_gemm/operation.hpp" #include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/operation.hpp"
#include "ck/host/stringutils.hpp" #include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp" #include "ck/host/utils.hpp"
#include <cassert> #include <cassert>
namespace ck { namespace ck {
namespace host { namespace host {
namespace device_gemm_elementwise_gemm { namespace device_batched_gemm_multiple_d_gemm_multiple_d {
// calculate appropriate Gemm Specification based on input tensor dimensions // calculate appropriate Gemm Specification based on input tensor dimensions
std::string GetGemmSpec(const std::size_t m, operation::PaddingDesc GetPaddingDesc(const std::size_t m,
const std::size_t n, const std::size_t n,
const std::size_t k, const std::size_t k,
const std::size_t n1, const std::size_t n1,
const std::size_t m_per_block, const std::size_t m_per_block,
const std::size_t n_per_block, const std::size_t n_per_block,
const std::size_t k_per_block, const std::size_t k_per_block,
const std::size_t n1_per_block) const std::size_t n1_per_block,
const std::size_t k1_per_block)
{ {
std::string spec = ""; operation::PaddingDesc desc;
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
spec += "M"; desc.pad_gemm0_m = true;
if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
spec += "N"; desc.pad_gemm0_n = true;
if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
spec += "K"; desc.pad_gemm0_k = true;
if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0) if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0)
spec += "O"; desc.pad_gemm1_n = true;
if(spec == "") if(integer_divide_ceil(n, k1_per_block) * k1_per_block - n != 0) // TODO is n == k1 ?
return "ck::tensor_operation::device::GemmSpecialization::Default"; desc.pad_gemm1_k = true;
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; return desc;
} }
// function to update prologue/epilogue with user provided operation // function to update prologue/epilogue with user provided operation
...@@ -41,6 +42,9 @@ void Operation_Xdl_CShuffle::update_prologue(const std::string& pro) ...@@ -41,6 +42,9 @@ void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
if(!prologue.empty()) if(!prologue.empty())
{ {
this->prologue = pro; this->prologue = pro;
// TODO is this right?
this->cde0_elem_op = "CDE0ElementOp";
this->cde1_elem_op = "CDE1ElementOp";
} }
else else
{ {
...@@ -53,6 +57,9 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) ...@@ -53,6 +57,9 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
if(!epilogue.empty()) if(!epilogue.empty())
{ {
this->epilogue = epi; this->epilogue = epi;
// TODO is this right?
this->cde0_elem_op = "CDE0ElementOp";
this->cde1_elem_op = "CDE1ElementOp";
} }
else else
{ {
...@@ -68,24 +75,19 @@ static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row ...@@ -68,24 +75,19 @@ static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
const Problem& prob, const std::string& prologue, const std::string& epilogue) const Problem& prob, const std::string& prologue, const std::string& epilogue)
{ {
assert(prob.TransA == false); std::vector<Operation_Xdl_CShuffle> result;
assert(prob.TransB0 == true);
assert(prob.TransC == false);
const auto b1k1 = prob.TransB1 ? 4 : 2; const auto b1k1 = prob.TransB1 ? 4 : 2;
std::vector<Operation_Xdl_CShuffle> result; std::vector<operation::TileDescGemmGemm> tile_descriptions = {
std::vector<operation::TileDescGemmElementwiseGemm> tile_descriptions = {
// clang-format off // clang-format off
// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK| // Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|NumGemm0K|
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch| // Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage| // | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
// | | | | | | | | | | | Wave| Wave| Wave| | // | | | | | | | | | | | Wave| Wave| Wave| |
{ 256, 256, 128, 32, 64, 32, 8, 8, b1k1, 32, 32, 2, 4, 2, 1}, //generic
{ 256, 256, 128, 32, 128, 32, 8, 8, b1k1, 32, 32, 2, 4, 4, 1}, { 256, 128, 64, 32, 128, 32, 8, 8, b1k1, 32, 32, 1, 2, 4, 1},
{ 256, 128, 256, 32, 64, 32, 8, 8, b1k1, 32, 32, 1, 8, 2, 1}, // no padding
{ 256, 128, 256, 32, 128, 32, 8, 8, b1k1, 32, 32, 1, 8, 4, 1},
{ 256, 128, 128, 64, 64, 32, 8, 8, b1k1, 32, 32, 1, 4, 2, 1}, { 256, 128, 128, 64, 64, 32, 8, 8, b1k1, 32, 32, 1, 4, 2, 1},
{ 256, 128, 128, 32, 64, 32, 8, 8, b1k1, 32, 32, 1, 4, 2, 1}, { 256, 128, 128, 32, 64, 32, 8, 8, b1k1, 32, 32, 1, 4, 2, 1},
{ 256, 128, 128, 64, 128, 32, 8, 8, b1k1, 32, 32, 1, 4, 4, 1}, { 256, 128, 128, 64, 128, 32, 8, 8, b1k1, 32, 32, 1, 4, 4, 1},
...@@ -94,22 +96,29 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -94,22 +96,29 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ 256, 64, 256, 32, 64, 32, 8, 8, b1k1, 16, 16, 1, 16, 4, 1}, { 256, 64, 256, 32, 64, 32, 8, 8, b1k1, 16, 16, 1, 16, 4, 1},
{ 256, 64, 256, 64, 128, 32, 8, 8, b1k1, 16, 16, 1, 16, 8, 1}, { 256, 64, 256, 64, 128, 32, 8, 8, b1k1, 16, 16, 1, 16, 8, 1},
{ 256, 64, 256, 64, 64, 32, 8, 8, b1k1, 16, 16, 1, 16, 4, 1}, { 256, 64, 256, 64, 64, 32, 8, 8, b1k1, 16, 16, 1, 16, 4, 1},
// Padded fallback kerne // Padded fallback kernel
{ 256, 128, 128, 64, 128, 32, 8, 8, b1k1, 32, 32, 1, 4, 4, 1}, { 256, 128, 128, 64, 128, 32, 8, 8, b1k1, 32, 32, 1, 4, 4, 1},
{ 256, 128, 64, 32, 128, 32, 8, 8, b1k1, 32, 32, 1, 2, 4, 1}, { 256, 128, 64, 32, 128, 32, 8, 8, b1k1, 32, 32, 1, 2, 4, 1},
// clang-format on // clang-format on
}; };
if(prob.TransB1)
{
// clang-format off
tile_descriptions.push_back(
{ 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, 1}
);
// clang-format on
}
const std::vector<operation::BlockTransferDesc> a_block_descriptions = { std::vector<operation::BlockTransferDesc> a0_block_descriptions = {
// clang-format off // clang-format off
// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| // A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| // ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | // Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| |
// | | | | | | | // | | | | | | |
//generic
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, // no padding
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
...@@ -123,90 +132,111 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -123,90 +132,111 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
// clang-format on // clang-format on
}; };
if(prob.TransB1)
const auto& b0_block_descriptions_rowmajor = a_block_descriptions; {
const std::vector<operation::BlockTransferDesc> b0_block_descriptions_colmajor = {
// clang-format off // clang-format off
// B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| a0_block_descriptions.push_back(
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | );
// | | | | | | |
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
// clang-format on // clang-format on
}; }
auto b0_block_descriptions = a0_block_descriptions;
if(prob.TransB1)
{
b0_block_descriptions[1].lds_add_extra_dim = true;
b0_block_descriptions[3].lds_add_extra_dim = true;
}
const std::vector<operation::BlockTransferDesc> b1_block_descriptions_rowmajor = { std::vector<operation::BlockTransferDesc> cde0_block_descriptions = {
// clang-format off // clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| // ... | CDE0BlockTransfer| CDE0BlockTransfer| ... |
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| // ... | SrcVectorDim| SrcScalar| ... |
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | // ... | | PerVector| ... |
// | | | | | | | // | | | |
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, //generic
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, {"", "", "", 9, 1, 0, 0},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, // no padding
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, {"", "", "", 9, 4, 0, 0},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, {"", "", "", 9, 4, 0, 0},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, {"", "", "", 9, 4, 0, 0},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, {"", "", "", 9, 4, 0, 0},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, {"", "", "", 9, 4, 0, 0},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, {"", "", "", 9, 4, 0, 0},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, {"", "", "", 9, 4, 0, 0},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, {"", "", "", 9, 4, 0, 0},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
// Padded fallback kernel // Padded fallback kernel
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, {"", "", "", 9, 4, 0, 0},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, {"", "", "", 9, 4, 0, 0},
// clang-format on // clang-format on
}; };
if(prob.TransB1)
const std::vector<operation::BlockTransferDesc> b1_block_descriptions_colmajor = { {
// clang-format off // clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| cde0_block_descriptions.push_back(
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| {"", "", "", 9, 4, 0, 0}
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | );
// | | | | | | |
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
// Padded fallback kernel
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
// clang-format on // clang-format on
}; }
const std::vector<operation::BlockTransferDesc> b1_block_descriptions_rowmajor =
{
// clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
//generic
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
// no padding
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
// Padded fallback kernel
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
// clang-format on
};
const std::vector<operation::BlockTransferDesc> b1_block_descriptions_colmajor =
{
// clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
//generic
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
// no padding
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
// Padded fallback kernel
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true},
// clang-format on
};
std::vector<operation::CShuffleDesc> cshuffle_descriptions = { std::vector<operation::CShuffleDesc> cshuffle_descriptions = {
// clang-format off // clang-format off
// CShuffle| CShuffle| // C1Shuffle| C1Shuffle|
// MXdlPerWave| NXdlPerWave| // MXdlPerWave| NXdlPerWave|
// PerShuffle| PerShuffle| // PerShuffle| PerShuffle|
// | | // | |
// generic
{ 1, 2}, { 1, 2},
{ 1, 2}, // no padding
{ 1, 2},
{ 1, 2},
{ 1, 2}, { 1, 2},
{ 1, 2}, { 1, 2},
{ 1, 2}, { 1, 2},
...@@ -220,69 +250,92 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -220,69 +250,92 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ 1, 2}, { 1, 2},
// clang-format on // clang-format on
}; };
if(prob.TransB1)
{
// clang-format off
cshuffle_descriptions.push_back(
{ 1, 2}
);
// clang-format on
}
std::vector<operation::CBlockTransferDesc> c_block_descriptions = { std::vector<operation::CBlockTransferDesc> cde1_block_descriptions = {
// clang-format off // clang-format off
// CBlockTransferClusterLengths| CBlockTransfer // CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
// _MBlock_MWaveMPerXdl| ScalarPerVector // _MBlock_MWaveMPerXdl| ScalarPerVector|
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl // _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// | // | |
{ S<1, 32, 1, 8>, 8}, // generic
{ S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8}, // no padding
{ S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8}, { S<1, 16, 1,16>, 8},
{ S<1, 16, 1,16>, 8}, { S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8}, { S<1, 16, 1,16>, 8},
{ S<1, 16, 1,16>, 8}, { S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
// Padded fallback kernel // Padded fallback kernel
{ S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8},
// clang-format on // clang-format on
}; };
if(prob.TransB1)
{
// clang-format off
cde1_block_descriptions.push_back(
{ S<1, 32, 1, 8>, 8}
);
// clang-format on
}
// choose correct arrangement of tuning parameters based on the layout of each tensor // choose correct arrangement of tuning parameters based on the layout of each tensor
const auto& b0_block_descriptions =
prob.TransB1 ? b0_block_descriptions_colmajor : b0_block_descriptions_rowmajor;
const auto& b1_block_descriptions = const auto& b1_block_descriptions =
prob.TransB1 ? b1_block_descriptions_colmajor : b1_block_descriptions_rowmajor; prob.TransB1 ? b1_block_descriptions_colmajor : b1_block_descriptions_rowmajor;
assert(tile_descriptions.size() == a_block_descriptions.size()); assert(tile_descriptions.size() == a0_block_descriptions.size());
assert(tile_descriptions.size() == b0_block_descriptions.size());
assert(tile_descriptions.size() == cde0_block_descriptions.size());
assert(tile_descriptions.size() == b1_block_descriptions.size()); assert(tile_descriptions.size() == b1_block_descriptions.size());
assert(tile_descriptions.size() == cshuffle_descriptions.size()); assert(tile_descriptions.size() == cshuffle_descriptions.size());
assert(tile_descriptions.size() == c_block_descriptions.size()); assert(tile_descriptions.size() == cde1_block_descriptions.size());
// Put all values together into a single operation > store into the result vector // Put all values together into a single operation > store into the result vector
for(std::size_t i = 0; i < tile_descriptions.size(); i++) for(std::size_t i = 0; i < tile_descriptions.size(); i++)
{ {
Operation_Xdl_CShuffle x; Operation_Xdl_CShuffle x;
x.tile_desc = tile_descriptions[i]; x.tile_desc = tile_descriptions[i];
x.a_block_transfer = a_block_descriptions[i]; x.a0_block_transfer = a0_block_descriptions[i];
x.b0_block_transfer = b0_block_descriptions[i]; x.b0_block_transfer = b0_block_descriptions[i];
x.cde0_block_transfer = cde0_block_descriptions[i];
x.b1_block_transfer = b1_block_descriptions[i]; x.b1_block_transfer = b1_block_descriptions[i];
x.cshuffle = cshuffle_descriptions[i]; x.cshuffle = cshuffle_descriptions[i];
x.c_block_transfer = c_block_descriptions[i]; x.cde1_block_transfer = cde1_block_descriptions[i];
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; x.A0 = TensorDesc{prob.A0DataType, ToLayout(prob.TransA0)};
x.B0 = TensorDesc{prob.B0DataType, ToLayout(prob.TransB0)}; x.B0 = TensorDesc{prob.B0DataType, ToLayout(prob.TransB0)};
x.D0s = Transform(prob.D0sTrans, prob.D0sDataType, [](auto trans, auto dt) {
return TensorDesc{dt, ToLayout(trans)};
});
x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)}; x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)};
x.C = TensorDesc{prob.CDataType, ToLayout(prob.TransC)}; x.D1s = Transform(prob.D1sTrans, prob.D1sDataType, [](auto trans, auto dt) {
x.a_elem_op = prob.AElementOp; return TensorDesc{dt, ToLayout(trans)};
});
x.E1 = TensorDesc{prob.E1DataType, ToLayout(prob.TransE1)};
x.a0_elem_op = prob.A0ElementOp;
x.b0_elem_op = prob.B0ElementOp; x.b0_elem_op = prob.B0ElementOp;
x.cde0_elem_op = prob.CDE0ElementOp;
x.b1_elem_op = prob.B1ElementOp; x.b1_elem_op = prob.B1ElementOp;
x.c_elem_op = prob.CElementOp; x.cde1_elem_op = prob.CDE1ElementOp;
x.acc0_elem_op = prob.Acc0ElementOp; x.padding_desc = GetPaddingDesc(prob.M,
x.gemm_specialization = GetGemmSpec(prob.M, prob.N,
prob.N, prob.K,
prob.K, prob.O,
prob.O, x.tile_desc.gemm0_m_per_block,
x.tile_desc.gemm01_m_per_block, x.tile_desc.gemm0_n_per_block,
x.tile_desc.gemm0_n_per_block, x.tile_desc.gemm0_k_per_block,
x.tile_desc.gemm0_k_per_block, x.tile_desc.gemm1_n_per_block,
x.tile_desc.gemm1_n_per_block); x.tile_desc.gemm1_k_per_block);
x.update_prologue(prologue); x.update_prologue(prologue);
x.update_epilogue(epilogue); x.update_epilogue(epilogue);
result.push_back(x); result.push_back(x);
...@@ -298,10 +351,7 @@ Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std: ...@@ -298,10 +351,7 @@ Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std:
std::vector<std::vector<Operation_Xdl_CShuffle>> operations; std::vector<std::vector<Operation_Xdl_CShuffle>> operations;
Problem prob; Problem prob;
prob.TransA = false;
prob.TransB0 = true; prob.TransB0 = true;
prob.TransB1 = false;
prob.TransC = false;
operations.push_back(CreateOperations(prob, prologue, epilogue)); operations.push_back(CreateOperations(prob, prologue, epilogue));
prob.TransB1 = true; prob.TransB1 = true;
...@@ -310,29 +360,43 @@ Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std: ...@@ -310,29 +360,43 @@ Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std:
return operations; return operations;
} }
static const char* const DeviceBatchedGemmGemm_Xdl_CShuffleTemplate = static const char* const DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffleTemplate =
"ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<${LayoutA}, " "ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<"
"${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, " "${A0Layout}, ${B0Layout}, ${D0sLayout}, ${B1Layout}, ${D1sLayout}, ${E1Layout}, "
"${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, "
"${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, " "${A0DataType}, ${B0DataType}, ${Acc0DataType}, ${D0sDataType}, ${B1DataType}, "
"${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, " "${Acc1DataType}, ${C1ShuffleDataType}, ${D1sDataType}, ${E1DataType}, "
"${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, "
"${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, " "${A0ElementwiseOperation}, ${B0ElementwiseOperation}, ${CDE0ElementwiseOperation}, "
"${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, " "${B1ElementwiseOperation}, ${CDE1ElementwiseOperation}, "
"${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, "
"${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, " "${PadGemm0M}, ${PadGemm0N}, ${PadGemm0K}, ${PadGemm1N}, ${PadGemm1K}, "
"${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, "
"${NumGemm0KPrefetchStage}, ${BlockSize}, ${Gemm0MPerBlock}, ${Gemm0NPerBlock}, "
"${Gemm0KPerBlock}, ${Gemm1NPerBlock}, ${Gemm1KPerBlock}, ${A0K1}, ${B0K1}, ${B1K1}, "
"${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, ${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, "
"${A0BlockTransferThreadClusterLengths_AK0_M_AK1}, "
"${A0BlockTransferThreadClusterArrangeOrder}, ${A0BlockTransferSrcAccessOrder}, "
"${A0BlockTransferSrcVectorDim}, ${A0BlockTransferSrcScalarPerVector}, "
"${A0BlockTransferDstScalarPerVector_AK1}, ${A0BlockLdsExtraM}, "
"${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, " "${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, " "${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, "
"${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, " "${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, "
"${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, " "${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, "
"${CDE0BlockTransferSrcVectorDim}, ${CDE0BlockTransferSrcScalarPerVector}, "
"${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, " "${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, " "${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, "
"${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, " "${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, "
"${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, " "${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, "
"${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, " "${C1ShuffleMXdlPerWavePerShuffle}, ${C1ShuffleGemm0NXdlPerWavePerShuffle}, "
"${CBlockTransferScalarPerVector_NWaveNPerXdl}>";
"${CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDE1ShuffleBlockTransferScalarPerVector_NPerBlock}>";
// use hardcoded instances from vector of operations to substitute values into instance template // use hardcoded instances from vector of operations to substitute values into instance template
Solution Operation_Xdl_CShuffle::ToSolution() const Solution Operation_Xdl_CShuffle::ToSolution() const
...@@ -340,60 +404,80 @@ Solution Operation_Xdl_CShuffle::ToSolution() const ...@@ -340,60 +404,80 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
std::unordered_map<std::string, std::string> values = { std::unordered_map<std::string, std::string> values = {
{"name", {"name",
std::to_string(this->tile_desc.block_size) + "_" + std::to_string(this->tile_desc.block_size) + "_" +
std::to_string(this->tile_desc.gemm01_m_per_block) + "_" + std::to_string(this->tile_desc.gemm0_m_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_n_per_block) + "_" + std::to_string(this->tile_desc.gemm0_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_k_per_block) + "_" + std::to_string(this->tile_desc.gemm0_k_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_n_per_block) + "_" + std::to_string(this->tile_desc.gemm1_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_k_per_block) + "_" + std::to_string(this->tile_desc.gemm1_k_per_block) + "_" +
std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + std::to_string(this->tile_desc.a0k1) + "_" + std::to_string(this->tile_desc.b0k1) +
std::to_string(this->tile_desc.b1k1) + "_" + "_" + std::to_string(this->tile_desc.b1k1) + "_" +
std::to_string(this->tile_desc.m_per_XDL) + "_" + std::to_string(this->tile_desc.m_per_XDL) + "_" +
std::to_string(this->tile_desc.n_per_XDL) + "_" + std::to_string(this->tile_desc.n_per_XDL) + "_" +
std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" + std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" +
std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" + std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" +
std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
{"LayoutA", ToString(this->A.layout)},
{"LayoutB0", ToString(this->B0.layout)}, {"A0Layout", ToString(this->A0.layout)},
{"LayoutB1", ToString(this->B1.layout)}, {"B0Layout", ToString(this->B0.layout)},
{"LayoutC", ToString(this->C.layout)}, {"D0sLayout",
{"ADataType", ToString(this->A.element)}, MakeTuple(Transform(this->D0s, [](auto tensor) { return ToString(tensor.layout); }))},
{"B1Layout", ToString(this->B1.layout)},
{"D1sLayout",
MakeTuple(Transform(this->D1s, [](auto tensor) { return ToString(tensor.layout); }))},
{"E1Layout", ToString(this->E1.layout)},
{"ADataType", ToString(this->A0.element)},
{"B0DataType", ToString(this->B0.element)}, {"B0DataType", ToString(this->B0.element)},
{"Acc0DataType", ToString(this->acc_type)},
{"D0sDataType",
MakeTuple(Transform(this->D0s, [](auto tensor) { return ToString(tensor.element); }))},
{"B1DataType", ToString(this->B1.element)}, {"B1DataType", ToString(this->B1.element)},
{"CDataType", ToString(this->C.element)}, {"Acc1DataType", ToString(this->acc_type)},
{"AccDataType", ToString(this->acc)}, {"C1ShuffleDataType", ToString(this->cshuffle_type)},
{"CShuffleDataType", ToString(this->cs_type)}, {"D1sDataType",
{"AElementwiseOperation", this->a_elem_op}, MakeTuple(Transform(this->D1s, [](auto tensor) { return ToString(tensor.element); }))},
{"E1DataType", ToString(this->E1.element)},
{"A0ElementwiseOperation", this->a0_elem_op},
{"B0ElementwiseOperation", this->b0_elem_op}, {"B0ElementwiseOperation", this->b0_elem_op},
{"Acc0ElementwiseOperation", this->acc0_elem_op}, {"CDE0ElementwiseOperation", this->cde0_elem_op},
{"B1ElementwiseOperation", this->b1_elem_op}, {"B1ElementwiseOperation", this->b1_elem_op},
{"CElementwiseOperation", this->c_elem_op}, {"CDE1ElementwiseOperation", this->cde1_elem_op},
{"GemmSpecialization", this->gemm_specialization},
{"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, {"PadGemm0M", std::to_string(this->padding_desc.pad_gemm0_m)},
{"PadGemm0N", std::to_string(this->padding_desc.pad_gemm0_n)},
{"PadGemm0K", std::to_string(this->padding_desc.pad_gemm0_k)},
{"PadGemm1N", std::to_string(this->padding_desc.pad_gemm1_n)},
{"PadGemm1K", std::to_string(this->padding_desc.pad_gemm1_k)},
{"NumGemm0KPrefetchStage", std::to_string(this->tile_desc.num_gemm0k_prefetch_stage)},
{"BlockSize", std::to_string(this->tile_desc.block_size)}, {"BlockSize", std::to_string(this->tile_desc.block_size)},
{"Gemm01MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)}, {"Gemm0MPerBlock", std::to_string(this->tile_desc.gemm0_m_per_block)},
{"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)}, {"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)},
{"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)}, {"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)},
{"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)}, {"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)},
{"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)}, {"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)},
{"AK1", std::to_string(this->tile_desc.ak1)}, {"A0K1", std::to_string(this->tile_desc.a0k1)},
{"BK1", std::to_string(this->tile_desc.bk1)}, {"B0K1", std::to_string(this->tile_desc.b0k1)},
{"B1K1", std::to_string(this->tile_desc.b1k1)}, {"B1K1", std::to_string(this->tile_desc.b1k1)},
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
{"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)}, {"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)},
{"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)}, {"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)},
{"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, {"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
{"ABlockTransferThreadClusterLengths_AK0_M_AK1",
this->a_block_transfer.thread_cluster_length}, {"A0BlockTransferThreadClusterLengths_AK0_M_AK1",
{"ABlockTransferThreadClusterArrangeOrder", this->a0_block_transfer.thread_cluster_length},
this->a_block_transfer.thread_cluster_arrange_order}, {"A0BlockTransferThreadClusterArrangeOrder",
{"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, this->a0_block_transfer.thread_cluster_arrange_order},
{"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, {"A0BlockTransferSrcAccessOrder", this->a0_block_transfer.src_access_order},
{"ABlockTransferSrcScalarPerVector", {"A0BlockTransferSrcVectorDim", std::to_string(this->a0_block_transfer.src_vec_dim)},
std::to_string(this->a_block_transfer.src_scalar_per_vector)}, {"A0BlockTransferSrcScalarPerVector",
{"ABlockTransferDstScalarPerVector_AK1", std::to_string(this->a0_block_transfer.src_scalar_per_vector)},
std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, {"A0BlockTransferDstScalarPerVector_AK1",
{"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, std::to_string(this->a0_block_transfer.dst_scalar_per_vector_k1)},
{"A0BlockLdsExtraM", std::to_string(this->a0_block_transfer.lds_add_extra_dim)},
{"B0BlockTransferThreadClusterLengths_BK0_N_BK1", {"B0BlockTransferThreadClusterLengths_BK0_N_BK1",
this->b0_block_transfer.thread_cluster_length}, this->b0_block_transfer.thread_cluster_length},
{"B0BlockTransferThreadClusterArrangeOrder", {"B0BlockTransferThreadClusterArrangeOrder",
...@@ -405,6 +489,11 @@ Solution Operation_Xdl_CShuffle::ToSolution() const ...@@ -405,6 +489,11 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
{"B0BlockTransferDstScalarPerVector_BK1", {"B0BlockTransferDstScalarPerVector_BK1",
std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)}, std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)},
{"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)}, {"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)},
{"CDE0BlockTransferSrcVectorDim", std::to_string(this->cde0_block_transfer.src_vec_dim)},
{"CDE0BlockTransferSrcScalarPerVector",
std::to_string(this->cde0_block_transfer.src_scalar_per_vector)},
{"B1BlockTransferThreadClusterLengths_BK0_N_BK1", {"B1BlockTransferThreadClusterLengths_BK0_N_BK1",
this->b1_block_transfer.thread_cluster_length}, this->b1_block_transfer.thread_cluster_length},
{"B1BlockTransferThreadClusterArrangeOrder", {"B1BlockTransferThreadClusterArrangeOrder",
...@@ -416,20 +505,24 @@ Solution Operation_Xdl_CShuffle::ToSolution() const ...@@ -416,20 +505,24 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
{"B1BlockTransferDstScalarPerVector_BK1", {"B1BlockTransferDstScalarPerVector_BK1",
std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)}, std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)},
{"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)}, {"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)},
{"CShuffleMXdlPerWavePerShuffle",
{"C1ShuffleMXdlPerWavePerShuffle",
std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)},
{"CShuffleNXdlPerWavePerShuffle", {"C1ShuffleGemm0NXdlPerWavePerShuffle",
std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)},
{"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl",
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, {"CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock",
{"CBlockTransferScalarPerVector_NWaveNPerXdl", this->cde1_block_transfer
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, .cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
{"CDE1ShuffleBlockTransferScalarPerVector_NPerBlock",
std::to_string(this->cde1_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
}; };
return Solution{InterpolateString(DeviceBatchedGemmGemm_Xdl_CShuffleTemplate, values), return Solution{
std::move(values)}; InterpolateString(DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffleTemplate, values),
std::move(values)};
} }
} // namespace device_gemm_elementwise_gemm } // namespace device_batched_gemm_multiple_d_gemm_multiple_d
} // namespace host } // namespace host
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment